
Commit 05c3ce9

Ignore skims (#54)
* allow ignoring skims
* update min numba
* add tests and doc for ignore
* shared memory pre-init
* non-dask reload of data
* improved logging
* log bytes loaded so far
* fix bug
* docs and tests
1 parent fb31a68 commit 05c3ce9

File tree

6 files changed: +362 -109 lines changed


docs/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ numpy >= 1.19
 pandas >= 1.2
 pyarrow >= 3.0.0
 xarray >= 0.20.0
-numba >= 0.54
+numba >= 0.57
 numexpr
 filelock
 sphinx-autosummary-accessors

envs/testing.yml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ dependencies:
 - xarray
 - dask
 - networkx
-- numba>=0.54
+- numba>=0.57
 - numexpr
 - sparse
 - filelock

sharrow/dataset.py

Lines changed: 223 additions & 89 deletions
@@ -5,8 +5,9 @@
 import hashlib
 import logging
 import re
-from collections.abc import Hashable, Mapping, Sequence
-from typing import Any
+import time
+from collections.abc import Hashable, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import pandas as pd
@@ -17,8 +18,13 @@
 from .accessors import register_dataset_method
 from .aster import extract_all_name_tokens
 from .categorical import _Categorical  # noqa
+from .shared_memory import si_units
 from .table import Table

+if TYPE_CHECKING:
+    import openmatrix
+
+
 logger = logging.getLogger("sharrow")

 well_known_names = {
@@ -283,7 +289,7 @@ def from_table(


 def from_omx(
-    omx,
+    omx: openmatrix.File,
     index_names=("otaz", "dtaz"),
     indexes="one-based",
     renames=None,
@@ -385,14 +391,23 @@ def from_omx(
     return xr.Dataset.from_dict(d)


+def _should_ignore(ignore, x):
+    if ignore is not None:
+        for i in ignore:
+            if re.match(i, x):
+                return True
+    return False
+
+
 def from_omx_3d(
-    omx,
+    omx: openmatrix.File | str | Iterable[openmatrix.File | str],
     index_names=("otaz", "dtaz", "time_period"),
     indexes=None,
     *,
     time_periods=None,
     time_period_sep="__",
     max_float_precision=32,
+    ignore=None,
 ):
     """
     Create a Dataset from an OMX file with an implicit third dimension.
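Note on the new `_should_ignore` helper above: it uses `re.match`, so each pattern is implicitly anchored at the start of the variable name but not at the end. A minimal sketch of the resulting behavior (the skim names here are hypothetical):

    _should_ignore(["DIST"], "DISTBIKE")   # True: "DIST" matches at the start of the name
    _should_ignore(["DIST$"], "DISTBIKE")  # False: "$" anchors the pattern at the end
    _should_ignore(["BIKE"], "DISTBIKE")   # False: re.match does not search mid-string
    _should_ignore(None, "DISTBIKE")       # False: no ignore list means nothing is skipped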
@@ -427,6 +442,12 @@ def from_omx_3d(
         precision, generally to save memory if they were stored as double
         precision but that level of detail is unneeded in the present
         application.
+    ignore : list-like, optional
+        A list of regular expressions that will be used to filter out
+        variables from the dataset. If any of the regular expressions
+        match the name of a variable, that variable will not be included
+        in the loaded dataset. This is useful for excluding variables that
+        are not needed in the current application.

     Returns
     -------
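A usage sketch of the new `ignore` argument; the file names, skim name patterns, and time periods below are hypothetical:

    from sharrow.dataset import from_omx_3d

    skims = from_omx_3d(
        ["skims_hwy.omx", "skims_trn.omx"],   # hypothetical OMX files
        time_periods=["AM", "MD", "PM"],
        ignore=["WALK_.*", ".*_hov3"],        # regex patterns, applied with re.match
    )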
@@ -435,103 +456,216 @@
     if not isinstance(omx, (list, tuple)):
         omx = [omx]

-    # handle both larch.OMX and openmatrix.open_file versions
-    if "larch" in type(omx[0]).__module__:
-        omx_shape = omx[0].shape
-        omx_lookup = omx[0].lookup
-    else:
-        omx_shape = omx[0].shape()
-        omx_lookup = omx[0].root["lookup"]
-    omx_data = []
-    omx_data_map = {}
-    for n, i in enumerate(omx):
-        if "larch" in type(i).__module__:
-            omx_data.append(i.data)
-            for k in i.data._v_children:
-                omx_data_map[k] = n
-        else:
-            omx_data.append(i.root["data"])
-            for k in i.root["data"]._v_children:
-                omx_data_map[k] = n
-
-    import dask.array
+    use_file_handles = []
+    opened_file_handles = []
+    for filename in omx:
+        if isinstance(filename, str):
+            import openmatrix

-    data_names = list(omx_data_map.keys())
-    n1, n2 = omx_shape
-    if indexes is None:
-        # default reads mapping if only one lookup is included, otherwise one-based
-        if len(omx_lookup._v_children) == 1:
-            ranger = None
-            indexes = list(omx_lookup._v_children)[0]
+            h = openmatrix.open_file(filename)
+            opened_file_handles.append(h)
+            use_file_handles.append(h)
+        else:
+            use_file_handles.append(filename)
+    omx = use_file_handles
+
+    try:
+        # handle both larch.OMX and openmatrix.open_file versions
+        if "larch" in type(omx[0]).__module__:
+            omx_shape = omx[0].shape
+            omx_lookup = omx[0].lookup
         else:
+            omx_shape = omx[0].shape()
+            omx_lookup = omx[0].root["lookup"]
+        omx_data = []
+        omx_data_map = {}
+        for n, i in enumerate(omx):
+            if "larch" in type(i).__module__:
+                omx_data.append(i.data)
+                for k in i.data._v_children:
+                    omx_data_map[k] = n
+            else:
+                omx_data.append(i.root["data"])
+                for k in i.root["data"]._v_children:
+                    omx_data_map[k] = n
+
+        import dask.array
+
+        data_names = list(omx_data_map.keys())
+        if ignore is not None:
+            if isinstance(ignore, str):
+                ignore = [ignore]
+            data_names = [i for i in data_names if not _should_ignore(ignore, i)]
+        n1, n2 = omx_shape
+        if indexes is None:
+            # default reads mapping if only one lookup is included, otherwise one-based
+            if len(omx_lookup._v_children) == 1:
+                ranger = None
+                indexes = list(omx_lookup._v_children)[0]
+            else:
+                ranger = one_based
+        elif indexes == "one-based":
             ranger = one_based
-    elif indexes == "one-based":
-        ranger = one_based
-    elif indexes == "zero-based":
-        ranger = zero_based
-    elif indexes in set(omx_lookup._v_children):
-        ranger = None
-    else:
-        raise NotImplementedError(
-            "only one-based, zero-based, and named indexes are implemented"
-        )
-    if ranger is not None:
-        r1 = ranger(n1)
-        r2 = ranger(n2)
-    else:
-        r1 = r2 = pd.Index(omx_lookup[indexes])
+        elif indexes == "zero-based":
+            ranger = zero_based
+        elif indexes in set(omx_lookup._v_children):
+            ranger = None
+        else:
+            raise NotImplementedError(
+                "only one-based, zero-based, and named indexes are implemented"
+            )
+        if ranger is not None:
+            r1 = ranger(n1)
+            r2 = ranger(n2)
+        else:
+            r1 = r2 = pd.Index(omx_lookup[indexes])

-    if time_periods is None:
-        raise ValueError("must give time periods explicitly")
+        if time_periods is None:
+            raise ValueError("must give time periods explicitly")

-    time_periods_map = {t: n for n, t in enumerate(time_periods)}
+        time_periods_map = {t: n for n, t in enumerate(time_periods)}

-    pending_3d = {}
-    content = {}
+        pending_3d = {}
+        content = {}

-    for k in data_names:
-        if time_period_sep in k:
-            base_k, time_k = k.split(time_period_sep, 1)
-            if base_k not in pending_3d:
-                pending_3d[base_k] = [None] * len(time_periods)
-            pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(
-                omx_data[omx_data_map[k]][k]
-            )
-        else:
-            content[k] = xr.DataArray(
-                dask.array.from_array(omx_data[omx_data_map[k]][k]),
-                dims=index_names[:2],
+        for k in data_names:
+            if time_period_sep in k:
+                base_k, time_k = k.split(time_period_sep, 1)
+                if base_k not in pending_3d:
+                    pending_3d[base_k] = [None] * len(time_periods)
+                pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(
+                    omx_data[omx_data_map[k]][k]
+                )
+            else:
+                content[k] = xr.DataArray(
+                    dask.array.from_array(omx_data[omx_data_map[k]][k]),
+                    dims=index_names[:2],
+                    coords={
+                        index_names[0]: r1,
+                        index_names[1]: r2,
+                    },
+                )
+        for base_k, darrs in pending_3d.items():
+            # find a prototype array
+            prototype = None
+            for i in darrs:
+                prototype = i
+                if prototype is not None:
+                    break
+            if prototype is None:
+                raise ValueError("no prototype")
+            darrs_ = [
+                (i if i is not None else dask.array.zeros_like(prototype))
+                for i in darrs
+            ]
+            content[base_k] = xr.DataArray(
+                dask.array.stack(darrs_, axis=-1),
+                dims=index_names,
                 coords={
                     index_names[0]: r1,
                     index_names[1]: r2,
+                    index_names[2]: time_periods,
                 },
             )
-    for base_k, darrs in pending_3d.items():
-        # find a prototype array
-        prototype = None
-        for i in darrs:
-            prototype = i
-            if prototype is not None:
-                break
-        if prototype is None:
-            raise ValueError("no prototype")
-        darrs_ = [
-            (i if i is not None else dask.array.zeros_like(prototype)) for i in darrs
-        ]
-        content[base_k] = xr.DataArray(
-            dask.array.stack(darrs_, axis=-1),
-            dims=index_names,
-            coords={
-                index_names[0]: r1,
-                index_names[1]: r2,
-                index_names[2]: time_periods,
-            },
-        )
-    for i in content:
-        if np.issubdtype(content[i].dtype, np.floating):
-            if content[i].dtype.itemsize > max_float_precision / 8:
-                content[i] = content[i].astype(f"float{max_float_precision}")
-    return xr.Dataset(content)
+        for i in content:
+            if np.issubdtype(content[i].dtype, np.floating):
+                if content[i].dtype.itemsize > max_float_precision / 8:
+                    content[i] = content[i].astype(f"float{max_float_precision}")
+        return xr.Dataset(content)
+    finally:
+        for h in opened_file_handles:
+            h.close()
+
+
+def reload_from_omx_3d(
+    dataset: xr.Dataset,
+    omx: Iterable[str],
+    *,
+    time_period_sep="__",
+    ignore=None,
+) -> None:
+    """
+    Reload the content of a dataset from OMX files.
+
+    This loads the data from the OMX files into the dataset, replacing
+    the existing data in the dataset. The dataset must have been created
+    by `from_omx_3d` or a similar function. Note that `from_omx_3d` will
+    create a dataset backed by `dask.array` objects; this function allows for
+    loading the data without going through dask, which may have poor performance
+    on some platforms.
+
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The dataset to reload into.
+    omx : Iterable[str]
+        The list of OMX file names to load from.
+    time_period_sep : str, default "__"
+        The separator used to identify time periods in the dataset.
+    ignore : list-like, optional
+        A list of regular expressions that will be used to filter out
+        variables from the dataset. If any of the regular expressions
+        match the name of a variable, that variable will not be included
+        in the load process. This is useful for excluding variables that
+        are not found in the target dataset.
+    """
+    if isinstance(ignore, str):
+        ignore = [ignore]
+
+    use_file_handles = []
+    opened_file_handles = []
+    for filename in omx:
+        if isinstance(filename, str):
+            import openmatrix
+
+            h = openmatrix.open_file(filename)
+            opened_file_handles.append(h)
+            use_file_handles.append(h)
+        else:
+            use_file_handles.append(filename)
+    omx = use_file_handles
+
+    bytes_loaded = 0
+
+    try:
+        t0 = time.time()
+        for filename, f in zip(omx, use_file_handles):
+            if isinstance(filename, str):
+                logger.info(f"loading into dataset from {filename}")
+            for data_name in f.root.data._v_children:
+                if _should_ignore(ignore, data_name):
+                    logger.info(f"ignoring {data_name}")
+                    continue
+                t1 = time.time()
+                filters = f.root.data[data_name].filters
+                filter_note = f"{filters.complib}/{filters.complevel}"
+
+                if time_period_sep in data_name:
+                    data_name_x, data_name_t = data_name.split(time_period_sep, 1)
+                    if len(dataset[data_name_x].dims) != 3:
+                        raise ValueError(
+                            f"dataset variable {data_name_x} has "
+                            f"{len(dataset[data_name_x].dims)} dimensions, expected 3"
+                        )
+                    raw = dataset[data_name_x].sel(time_period=data_name_t).data
+                    raw[:, :] = f.root.data[data_name][:, :]
+                else:
+                    if len(dataset[data_name].dims) != 2:
+                        raise ValueError(
+                            f"dataset variable {data_name} has "
+                            f"{len(dataset[data_name].dims)} dimensions, expected 2"
+                        )
+                    raw = dataset[data_name].data
+                    raw[:, :] = f.root.data[data_name][:, :]
+                bytes_loaded += raw.nbytes
+                logger.info(
+                    f"loaded {data_name} ({filter_note}) to dataset "
+                    f"in {time.time() - t1:.2f}s, {si_units(bytes_loaded)}"
+                )
+        logger.info(f"loading to dataset complete in {time.time() - t0:.2f}s")
+    finally:
+        for h in opened_file_handles:
+            h.close()


 def from_amx(
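A possible end-to-end use of the non-dask reload path added in this commit; the file names and ignore pattern are hypothetical, and the exact surrounding workflow (e.g. shared-memory pre-initialization) may differ:

    from sharrow.dataset import from_omx_3d, reload_from_omx_3d

    files = ["skims_2020.omx", "skims_2025.omx"]   # hypothetical file names
    skims = from_omx_3d(files, time_periods=["AM", "PM"], ignore=["tmp_.*"])
    # materialize the dask-backed arrays (e.g. into local or shared memory) ...
    skims = skims.compute()
    # ... then refresh the values in place without going through dask
    reload_from_omx_3d(skims, files, ignore=["tmp_.*"])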
