|
26 | 26 | import pandas as pd |
27 | 27 | import numpy as np |
28 | 28 | from packaging import version |
| 29 | +from iblutil.util import Bunch |
29 | 30 | from iblutil.io import parquet |
30 | 31 | from iblutil.io.hashfile import md5 |
31 | 32 |
|
32 | 33 | from one.alf.spec import QC, is_uuid_string |
33 | 34 | from one.alf.io import iter_sessions |
34 | 35 | from one.alf.path import session_path_parts, get_alf_path |
35 | 36 |
|
36 | | -__all__ = ['make_parquet_db', 'patch_cache', 'remove_missing_datasets', |
37 | | - 'remove_cache_table_files', 'EMPTY_DATASETS_FRAME', 'EMPTY_SESSIONS_FRAME', 'QC_TYPE'] |
| 37 | +__all__ = [ |
| 38 | + 'make_parquet_db', 'patch_tables', 'merge_tables', 'QC_TYPE', 'remove_table_files', |
| 39 | + 'remove_missing_datasets', 'load_tables', 'EMPTY_DATASETS_FRAME', 'EMPTY_SESSIONS_FRAME'] |
38 | 40 | _logger = logging.getLogger(__name__) |
39 | 41 |
|
40 | 42 | # ------------------------------------------------------------------------------------------------- |
@@ -356,6 +358,146 @@ def cast_index_object(df: pd.DataFrame, dtype: type = uuid.UUID) -> pd.Index: |
356 | 358 | return df |
357 | 359 |
|
358 | 360 |
|
| 361 | +def load_tables(tables_dir, glob_pattern='*.pqt'): |
| 362 | + """Load parquet cache files from a local directory. |
| 363 | +
|
| 364 | + Parameters |
| 365 | + ---------- |
| 366 | + tables_dir : str, pathlib.Path |
| 367 | + The directory location of the parquet files. |
| 368 | + glob_pattern : str |
| 369 | + A glob pattern to match the cache files. |
| 370 | +
|
| 371 | +
|
| 372 | + Returns |
| 373 | + ------- |
| 374 | + Bunch |
| 375 | + A Bunch object containing the loaded cache tables and associated metadata. |
| 376 | +
|
| 377 | + """ |
| 378 | + meta = { |
| 379 | + 'expired': False, |
| 380 | + 'created_time': None, |
| 381 | + 'loaded_time': None, |
| 382 | + 'modified_time': None, |
| 383 | + 'saved_time': None, |
| 384 | + 'raw': {} |
| 385 | + } |
| 386 | + caches = Bunch({ |
| 387 | + 'datasets': EMPTY_DATASETS_FRAME.copy(), |
| 388 | + 'sessions': EMPTY_SESSIONS_FRAME.copy(), |
| 389 | + '_meta': meta}) |
| 390 | + INDEX_KEY = '.?id' |
| 391 | + for cache_file in Path(tables_dir).glob(glob_pattern): |
| 392 | + table = cache_file.stem |
| 393 | + # we need to keep this part fast enough for transient objects |
| 394 | + cache, meta['raw'][table] = parquet.load(cache_file) |
| 395 | + if 'date_created' not in meta['raw'][table]: |
| 396 | + _logger.warning(f"{cache_file} does not appear to be a valid table. Skipping") |
| 397 | + continue |
| 398 | + meta['loaded_time'] = datetime.datetime.now() |
| 399 | + |
| 400 | + # Set the appropriate index if none already set |
| 401 | + if isinstance(cache.index, pd.RangeIndex): |
| 402 | + idx_columns = sorted(cache.filter(regex=INDEX_KEY).columns) |
| 403 | + if len(idx_columns) == 0: |
| 404 | + raise KeyError('Failed to set index') |
| 405 | + cache.set_index(idx_columns, inplace=True) |
| 406 | + |
| 407 | + # Patch older tables |
| 408 | + cache = patch_tables(cache, meta['raw'][table].get('min_api_version'), table) |
| 409 | + |
| 410 | + # Cast indices to UUID |
| 411 | + cache = cast_index_object(cache, uuid.UUID) |
| 412 | + |
| 413 | + # Check sorted |
| 414 | + # Sorting makes MultiIndex indexing O(N) -> O(1) |
| 415 | + if not cache.index.is_monotonic_increasing: |
| 416 | + cache.sort_index(inplace=True) |
| 417 | + |
| 418 | + caches[table] = cache |
| 419 | + |
| 420 | + created = [datetime.datetime.fromisoformat(x['date_created']) |
| 421 | + for x in meta['raw'].values() if 'date_created' in x] |
| 422 | + if created: |
| 423 | + meta['created_time'] = min(created) |
| 424 | + return caches |
| 425 | + |
| 426 | + |
| 427 | +def merge_tables(cache, strict=False, **kwargs): |
| 428 | + """Update the cache tables with new records. |
| 429 | +
|
| 430 | + Parameters |
| 431 | + ---------- |
| 432 | +    cache : dict
| 433 | + A map of cache tables to update. |
| 434 | + strict : bool |
| 435 | + If not True, the columns don't need to match. Extra columns in input tables are |
| 436 | + dropped and missing columns are added and filled with np.nan. |
| 437 | + kwargs |
| 438 | + pandas.DataFrame or pandas.Series to insert/update for each table. |
| 439 | +
|
| 440 | + Returns |
| 441 | + ------- |
| 442 | + datetime.datetime: |
| 443 | + A timestamp of when the cache was updated. |
| 444 | +
|
| 445 | + Example |
| 446 | + ------- |
| 447 | + >>> session, datasets = ses2records(self.get_details(eid, full=True)) |
| 448 | +    ... merge_tables(cache, sessions=session, datasets=datasets)
| 449 | +
|
| 450 | + Raises |
| 451 | + ------ |
| 452 | + AssertionError |
| 453 | +        When strict is True the input columns must exactly match those of the cache table,
| 454 | + including the order. |
| 455 | + KeyError |
| 456 | + One or more of the keyword arguments does not match a table in cache. |
| 457 | +
|
| 458 | + """ |
| 459 | + updated = None |
| 460 | + for table, records in kwargs.items(): |
| 461 | + if records is None or records.empty: |
| 462 | + continue |
| 463 | + if table not in cache: |
| 464 | + raise KeyError(f'Table "{table}" not in cache') |
| 465 | + if isinstance(records, pd.Series): |
| 466 | + records = pd.DataFrame([records]) |
| 467 | + records.index.set_names(cache[table].index.names, inplace=True) |
| 468 | + # Drop duplicate indices |
| 469 | + records = records[~records.index.duplicated(keep='first')] |
| 470 | + if not strict: |
| 471 | + # Deal with case where there are extra columns in the cache |
| 472 | + extra_columns = list(set(cache[table].columns) - set(records.columns)) |
| 473 | + # Convert these columns to nullable, if required |
| 474 | + cache_columns = cache[table][extra_columns] |
| 475 | + cache[table][extra_columns] = cache_columns.convert_dtypes() |
| 476 | + column_ids = map(list(cache[table].columns).index, extra_columns) |
| 477 | + for col, n in sorted(zip(extra_columns, column_ids), key=lambda x: x[1]): |
| 478 | + dtype = cache[table][col].dtype |
| 479 | + nan = getattr(dtype, 'na_value', np.nan) |
| 480 | + val = records.get('exists', True) if col.startswith('exists_') else nan |
| 481 | + records.insert(n, col, val) |
| 482 | + # Drop any extra columns in the records that aren't in cache table |
| 483 | + to_drop = set(records.columns) - set(cache[table].columns) |
| 484 | + records = records.drop(to_drop, axis=1) |
| 485 | + records = records.reindex(columns=cache[table].columns) |
| 486 | + assert set(cache[table].columns) == set(records.columns) |
| 487 | + records = records.astype(cache[table].dtypes) |
| 488 | + # Update existing rows |
| 489 | + to_update = records.index.isin(cache[table].index) |
| 490 | + cache[table].loc[records.index[to_update], :] = records[to_update] |
| 491 | + # Assign new rows |
| 492 | + to_assign = records[~to_update] |
| 493 | + frames = [cache[table], to_assign] |
| 494 | + # Concatenate and sort |
| 495 | + cache[table] = pd.concat(frames).sort_index() |
| 496 | + updated = datetime.datetime.now() |
| 497 | + cache['_meta']['modified_time'] = updated |
| 498 | + return updated |
| 499 | + |
| 500 | + |
359 | 501 | def remove_missing_datasets(cache_dir, tables=None, remove_empty_sessions=True, dry=True): |
360 | 502 | """Remove dataset files and session folders that are not in the provided cache. |
361 | 503 |
|
@@ -383,7 +525,7 @@ def remove_missing_datasets(cache_dir, tables=None, remove_empty_sessions=True, |
383 | 525 | tables = {} |
384 | 526 | for name in ('datasets', 'sessions'): |
385 | 527 | table, m = parquet.load(cache_dir / f'{name}.pqt') |
386 | | - tables[name] = patch_cache(table, m.get('min_api_version'), name) |
| 528 | + tables[name] = patch_tables(table, m.get('min_api_version'), name) |
387 | 529 |
|
388 | 530 | INDEX_KEY = '.?id' |
389 | 531 | for name in tables: |
@@ -432,7 +574,7 @@ def remove_missing_datasets(cache_dir, tables=None, remove_empty_sessions=True, |
432 | 574 | return sorted(to_delete) |
433 | 575 |
|
434 | 576 |
|
435 | | -def remove_cache_table_files(folder, tables=('sessions', 'datasets')): |
| 577 | +def remove_table_files(folder, tables=('sessions', 'datasets')): |
436 | 578 | """Delete cache tables on disk. |
437 | 579 |
|
438 | 580 | Parameters |
@@ -482,7 +624,7 @@ def _cache_int2str(table: pd.DataFrame) -> pd.DataFrame: |
482 | 624 | return table |
483 | 625 |
|
484 | 626 |
|
485 | | -def patch_cache(table: pd.DataFrame, min_api_version=None, name=None) -> pd.DataFrame: |
| 627 | +def patch_tables(table: pd.DataFrame, min_api_version=None, name=None) -> pd.DataFrame: |
486 | 628 | """Reformat older cache tables to comply with this version of ONE. |
487 | 629 |
|
488 | 630 | Currently this function will 1. convert integer UUIDs to string UUIDs; 2. rename the 'project' |
|
0 commit comments