 import hashlib
 import logging
 import re
-from collections.abc import Hashable, Mapping, Sequence
-from typing import Any
+import time
+from collections.abc import Hashable, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
 import pandas as pd
 from .accessors import register_dataset_method
 from .aster import extract_all_name_tokens
 from .categorical import _Categorical  # noqa
+from .shared_memory import si_units
 from .table import Table
 
+if TYPE_CHECKING:
+    import openmatrix
+
+
 logger = logging.getLogger("sharrow")
 
 well_known_names = {
@@ -283,7 +289,7 @@ def from_table(
 
 
 def from_omx(
-    omx,
+    omx: openmatrix.File,
     index_names=("otaz", "dtaz"),
     indexes="one-based",
     renames=None,
@@ -385,14 +391,23 @@ def from_omx(
     return xr.Dataset.from_dict(d)
 
 
+def _should_ignore(ignore, x):
+    if ignore is not None:
+        for i in ignore:
+            if re.match(i, x):
+                return True
+    return False
+
+
 def from_omx_3d(
-    omx,
+    omx: openmatrix.File | str | Iterable[openmatrix.File | str],
     index_names=("otaz", "dtaz", "time_period"),
     indexes=None,
     *,
     time_periods=None,
     time_period_sep="__",
     max_float_precision=32,
+    ignore=None,
 ):
     """
     Create a Dataset from an OMX file with an implicit third dimension.
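
The `_should_ignore` helper added in this hunk relies on `re.match`, which anchors each pattern at the beginning of the matrix name, so a bare pattern acts as a prefix match. A minimal sketch of that behavior; the matrix names here are hypothetical:

    import re

    def _should_ignore(ignore, x):
        if ignore is not None:
            for i in ignore:
                if re.match(i, x):
                    return True
        return False

    # re.match anchors at the start of the string:
    assert _should_ignore(["DIST"], "DISTBIKE")     # prefix match succeeds
    assert not _should_ignore(["DIST"], "SOVDIST")  # a mid-string match does not
    assert _should_ignore([".*DIST"], "SOVDIST")    # a leading wildcard matches anywhere
    assert not _should_ignore(None, "DISTBIKE")     # with no patterns, nothing is ignored
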
@@ -427,6 +442,12 @@ def from_omx_3d(
         precision, generally to save memory if they were stored as double
         precision but that level of detail is unneeded in the present
         application.
+    ignore : list-like, optional
+        A list of regular expressions used to filter variables out of the
+        dataset. If any of the regular expressions matches the name of a
+        variable, that variable is not included in the loaded dataset. This
+        is useful for excluding variables that are not needed in the current
+        application.
 
     Returns
     -------
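
Matrices whose names contain `time_period_sep` are split on that separator and stacked along the third dimension, so tables named like "SOV_TIME__AM" and "SOV_TIME__PM" load as a single three-dimensional variable "SOV_TIME". A hedged usage sketch, assuming the function is importable as `sharrow.dataset.from_omx_3d`; the file names, matrix names, and time period labels are hypothetical:

    import sharrow as sh

    skims = sh.dataset.from_omx_3d(
        ["skims_a.omx", "skims_b.omx"],  # hypothetical OMX skim files
        time_periods=["EA", "AM", "MD", "PM", "EV"],  # required, must be given explicitly
        max_float_precision=32,  # downcast float64 matrices to float32
        ignore=["^DIST"],  # skip matrices whose names start with DIST
    )
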
@@ -435,103 +456,216 @@ def from_omx_3d(
     if not isinstance(omx, (list, tuple)):
         omx = [omx]
 
-    # handle both larch.OMX and openmatrix.open_file versions
-    if "larch" in type(omx[0]).__module__:
-        omx_shape = omx[0].shape
-        omx_lookup = omx[0].lookup
-    else:
-        omx_shape = omx[0].shape()
-        omx_lookup = omx[0].root["lookup"]
-    omx_data = []
-    omx_data_map = {}
-    for n, i in enumerate(omx):
-        if "larch" in type(i).__module__:
-            omx_data.append(i.data)
-            for k in i.data._v_children:
-                omx_data_map[k] = n
-        else:
-            omx_data.append(i.root["data"])
-            for k in i.root["data"]._v_children:
-                omx_data_map[k] = n
-
-    import dask.array
+    use_file_handles = []
+    opened_file_handles = []
+    for filename in omx:
+        if isinstance(filename, str):
+            import openmatrix
 
-    data_names = list(omx_data_map.keys())
-    n1, n2 = omx_shape
-    if indexes is None:
-        # default reads mapping if only one lookup is included, otherwise one-based
-        if len(omx_lookup._v_children) == 1:
-            ranger = None
-            indexes = list(omx_lookup._v_children)[0]
+            h = openmatrix.open_file(filename)
+            opened_file_handles.append(h)
+            use_file_handles.append(h)
+        else:
+            use_file_handles.append(filename)
+    omx = use_file_handles
+
+    try:
+        # handle both larch.OMX and openmatrix.open_file versions
+        if "larch" in type(omx[0]).__module__:
+            omx_shape = omx[0].shape
+            omx_lookup = omx[0].lookup
         else:
+            omx_shape = omx[0].shape()
+            omx_lookup = omx[0].root["lookup"]
+        omx_data = []
+        omx_data_map = {}
+        for n, i in enumerate(omx):
+            if "larch" in type(i).__module__:
+                omx_data.append(i.data)
+                for k in i.data._v_children:
+                    omx_data_map[k] = n
+            else:
+                omx_data.append(i.root["data"])
+                for k in i.root["data"]._v_children:
+                    omx_data_map[k] = n
+
+        import dask.array
+
+        data_names = list(omx_data_map.keys())
+        if ignore is not None:
+            if isinstance(ignore, str):
+                ignore = [ignore]
+            data_names = [i for i in data_names if not _should_ignore(ignore, i)]
+        n1, n2 = omx_shape
+        if indexes is None:
+            # default reads mapping if only one lookup is included, otherwise one-based
+            if len(omx_lookup._v_children) == 1:
+                ranger = None
+                indexes = list(omx_lookup._v_children)[0]
+            else:
+                ranger = one_based
+        elif indexes == "one-based":
             ranger = one_based
-    elif indexes == "one-based":
-        ranger = one_based
-    elif indexes == "zero-based":
-        ranger = zero_based
-    elif indexes in set(omx_lookup._v_children):
-        ranger = None
-    else:
-        raise NotImplementedError(
-            "only one-based, zero-based, and named indexes are implemented"
-        )
-    if ranger is not None:
-        r1 = ranger(n1)
-        r2 = ranger(n2)
-    else:
-        r1 = r2 = pd.Index(omx_lookup[indexes])
+        elif indexes == "zero-based":
+            ranger = zero_based
+        elif indexes in set(omx_lookup._v_children):
+            ranger = None
+        else:
+            raise NotImplementedError(
+                "only one-based, zero-based, and named indexes are implemented"
+            )
+        if ranger is not None:
+            r1 = ranger(n1)
+            r2 = ranger(n2)
+        else:
+            r1 = r2 = pd.Index(omx_lookup[indexes])
 
-    if time_periods is None:
-        raise ValueError("must give time periods explicitly")
+        if time_periods is None:
+            raise ValueError("must give time periods explicitly")
 
-    time_periods_map = {t: n for n, t in enumerate(time_periods)}
+        time_periods_map = {t: n for n, t in enumerate(time_periods)}
 
-    pending_3d = {}
-    content = {}
+        pending_3d = {}
+        content = {}
 
-    for k in data_names:
-        if time_period_sep in k:
-            base_k, time_k = k.split(time_period_sep, 1)
-            if base_k not in pending_3d:
-                pending_3d[base_k] = [None] * len(time_periods)
-            pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(
-                omx_data[omx_data_map[k]][k]
-            )
-        else:
-            content[k] = xr.DataArray(
-                dask.array.from_array(omx_data[omx_data_map[k]][k]),
-                dims=index_names[:2],
+        for k in data_names:
+            if time_period_sep in k:
+                base_k, time_k = k.split(time_period_sep, 1)
+                if base_k not in pending_3d:
+                    pending_3d[base_k] = [None] * len(time_periods)
+                pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(
+                    omx_data[omx_data_map[k]][k]
+                )
+            else:
+                content[k] = xr.DataArray(
+                    dask.array.from_array(omx_data[omx_data_map[k]][k]),
+                    dims=index_names[:2],
+                    coords={
+                        index_names[0]: r1,
+                        index_names[1]: r2,
+                    },
+                )
+        for base_k, darrs in pending_3d.items():
+            # find a prototype array
+            prototype = None
+            for i in darrs:
+                prototype = i
+                if prototype is not None:
+                    break
+            if prototype is None:
+                raise ValueError("no prototype")
+            darrs_ = [
+                (i if i is not None else dask.array.zeros_like(prototype))
+                for i in darrs
+            ]
+            content[base_k] = xr.DataArray(
+                dask.array.stack(darrs_, axis=-1),
+                dims=index_names,
                 coords={
                     index_names[0]: r1,
                     index_names[1]: r2,
+                    index_names[2]: time_periods,
                 },
             )
-    for base_k, darrs in pending_3d.items():
-        # find a prototype array
-        prototype = None
-        for i in darrs:
-            prototype = i
-            if prototype is not None:
-                break
-        if prototype is None:
-            raise ValueError("no prototype")
-        darrs_ = [
-            (i if i is not None else dask.array.zeros_like(prototype)) for i in darrs
-        ]
-        content[base_k] = xr.DataArray(
-            dask.array.stack(darrs_, axis=-1),
-            dims=index_names,
-            coords={
-                index_names[0]: r1,
-                index_names[1]: r2,
-                index_names[2]: time_periods,
-            },
-        )
-    for i in content:
-        if np.issubdtype(content[i].dtype, np.floating):
-            if content[i].dtype.itemsize > max_float_precision / 8:
-                content[i] = content[i].astype(f"float{max_float_precision}")
-    return xr.Dataset(content)
+        for i in content:
+            if np.issubdtype(content[i].dtype, np.floating):
+                if content[i].dtype.itemsize > max_float_precision / 8:
+                    content[i] = content[i].astype(f"float{max_float_precision}")
+        return xr.Dataset(content)
+    finally:
+        for h in opened_file_handles:
+            h.close()
+
+
+def reload_from_omx_3d(
+    dataset: xr.Dataset,
+    omx: Iterable[str],
+    *,
+    time_period_sep="__",
+    ignore=None,
+) -> None:
+    """
+    Reload the content of a dataset from OMX files.
+
+    This loads the data from the OMX files into the dataset, replacing the
+    existing data in the dataset. The dataset must have been created by
+    `from_omx_3d` or a similar function. Note that `from_omx_3d` creates a
+    dataset backed by `dask.array` objects; this function allows the data to
+    be loaded without going through dask, which can perform poorly on some
+    platforms.
+
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The dataset to reload into.
+    omx : Iterable[str]
+        The list of OMX file names to load from.
+    time_period_sep : str, default "__"
+        The separator used to identify time periods in the dataset.
+    ignore : list-like, optional
+        A list of regular expressions used to filter variables out of the
+        load process. If any of the regular expressions matches the name of
+        a variable, that variable is skipped. This is useful for excluding
+        variables that are not found in the target dataset.
+    """
+    if isinstance(ignore, str):
+        ignore = [ignore]
+
+    omx = list(omx)  # materialize the names so they can be reused in zip() below
+    use_file_handles = []
+    opened_file_handles = []
+    for filename in omx:
+        if isinstance(filename, str):
+            import openmatrix
+
+            h = openmatrix.open_file(filename)
+            opened_file_handles.append(h)
+            use_file_handles.append(h)
+        else:
+            use_file_handles.append(filename)
+
+    bytes_loaded = 0
+
+    try:
+        t0 = time.time()
+        # pair each original entry (name or passed-in handle) with its open handle,
+        # so the filename can be logged below; reassigning omx to the handle list
+        # here would make the isinstance check in the loop always false
+        for filename, f in zip(omx, use_file_handles):
+            if isinstance(filename, str):
+                logger.info(f"loading into dataset from {filename}")
+            for data_name in f.root.data._v_children:
+                if _should_ignore(ignore, data_name):
+                    logger.info(f"ignoring {data_name}")
+                    continue
+                t1 = time.time()
+                filters = f.root.data[data_name].filters
+                filter_note = f"{filters.complib}/{filters.complevel}"
+
+                if time_period_sep in data_name:
+                    data_name_x, data_name_t = data_name.split(time_period_sep, 1)
+                    if len(dataset[data_name_x].dims) != 3:
+                        raise ValueError(
+                            f"dataset variable {data_name_x} has "
+                            f"{len(dataset[data_name_x].dims)} dimensions, expected 3"
+                        )
+                    raw = dataset[data_name_x].sel(time_period=data_name_t).data
+                    raw[:, :] = f.root.data[data_name][:, :]
+                else:
+                    if len(dataset[data_name].dims) != 2:
+                        raise ValueError(
+                            f"dataset variable {data_name} has "
+                            f"{len(dataset[data_name].dims)} dimensions, expected 2"
+                        )
+                    raw = dataset[data_name].data
+                    raw[:, :] = f.root.data[data_name][:, :]
+                bytes_loaded += raw.nbytes
+                logger.info(
+                    f"loaded {data_name} ({filter_note}) to dataset "
+                    f"in {time.time() - t1:.2f}s, {si_units(bytes_loaded)}"
+                )
+        logger.info(f"loading to dataset complete in {time.time() - t0:.2f}s")
+    finally:
+        for h in opened_file_handles:
+            h.close()
 
 
 def from_amx(
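
A hedged sketch of how the new `reload_from_omx_3d` might be used to refresh skim data in place after the OMX files change on disk. It assumes the dataset's arrays are writable in place (for example, materialized into memory rather than left as lazy dask graphs); the file name and time period labels are hypothetical:

    import sharrow as sh

    skims = sh.dataset.from_omx_3d(
        ["skims_a.omx"],  # hypothetical OMX skim file
        time_periods=["AM", "PM"],
    ).compute()  # materialize the dask-backed arrays so they can be overwritten

    # ... later, after skims_a.omx has been regenerated on disk:
    sh.dataset.reload_from_omx_3d(skims, ["skims_a.omx"])
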