11import ast
22import logging
33import warnings
4+ from collections .abc import Mapping , Sequence
5+ from typing import Literal
46
57import networkx as nx
68import numpy as np
79import pandas as pd
810import xarray as xr
911
1012from .dataset import Dataset , construct
13+ from .tree_branch import DataTreeBranch
1114
1215try :
1316 from dask .array import Array as dask_array_type
@@ -69,7 +72,10 @@ def _ixname():
6972 return f"index{ inum } "
7073
7174 for k , v in idxs .items ():
72- loaders [k ] = xr .DataArray (v , dims = [_ixname () for n in range (v .ndim )])
75+ if isinstance (v , xr .DataArray ):
76+ loaders [k ] = v
77+ else :
78+ loaders [k ] = xr .DataArray (v , dims = [_ixname () for n in range (v .ndim )])
7379 if _names :
7480 ds = source [_names ]
7581 else :
@@ -91,7 +97,10 @@ def _ixname():
9197 return f"index{ inum } "
9298
9399 for k , v in idxs .items ():
94- loaders [k ] = xr .DataArray (v , dims = [_ixname () for n in range (v .ndim )])
100+ if isinstance (v , xr .DataArray ):
101+ loaders [k ] = v
102+ else :
103+ loaders [k ] = xr .DataArray (v , dims = [_ixname () for n in range (v .ndim )])
95104 if _names :
96105 ds = source [_names ]
97106 else :
@@ -575,8 +584,6 @@ def add_dataset(self, name, dataset, relationships=(), as_root=False):
575584 self .digitize_relationships (inplace = True )
576585
577586 def add_items (self , items ):
578- from collections .abc import Mapping , Sequence
579-
580587 if isinstance (items , Sequence ):
581588 for i in items :
582589 self .add_items (i )
@@ -621,7 +628,15 @@ def _get_relationship(self, edge):
621628 )
622629
def __getitem__(self, item):
    """Retrieve *item* from this tree.

    Any active eval cache is consulted first, then normal `get` lookup.
    When `get` raises KeyError but *item* names a node in the tree graph
    that carries a dataset, a DataTreeBranch wrapping that node is
    returned instead of raising.
    """
    if hasattr(self, "_eval_cache"):
        cached = self._eval_cache
        if item in cached:
            return cached[item]
    try:
        return self.get(item)
    except KeyError:
        # not a variable; perhaps the item names a whole subspace (node)
        node_info = self._graph.nodes.get(item, {})
        if node_info.get("dataset") is not None:
            return DataTreeBranch(self, item)
        raise
625640
626641 def get (self , item , default = None , broadcast = True , coords = True ):
627642 """
@@ -687,6 +702,11 @@ def get(self, item, default=None, broadcast=True, coords=True):
687702 add_coords [i ] = base_dataset .coords [i ]
688703 if add_coords :
689704 result = result .assign_coords (add_coords )
705+ if broadcast :
706+ if self .dim_order is None :
707+ result = result .transpose (* self .root_dims )
708+ else :
709+ result = result .transpose (* self .dim_order )
690710 return result
691711
692712 def finditem (self , item , maybe_in = None ):
@@ -828,6 +848,32 @@ def _getitem(
828848 _positions [r .child_name ] = _idx
829849 if top_dim_name is not None :
830850 top_dim_names [r .child_name ] = top_dim_name
851+ if len (top_dim_names ) > 1 :
852+ if len (set (top_dim_names .values ())) == 1 :
853+ # capture the situation where all top dims are the same
854+ _positions = {
855+ k : xr .DataArray (v , dims = [top_dim_names [k ]])
856+ for (k , v ) in _positions .items ()
857+ }
858+ _labels = {
859+ k : xr .DataArray (v , dims = [top_dim_names [k ]])
860+ for (k , v ) in _labels .items ()
861+ }
862+ # the top dim names have served their purpose, so clear them
863+ top_dim_names = {}
864+ elif len (set (top_dim_names .values ())) < len (top_dim_names ):
865+ # capture the situation where some but not all top dims are the same
866+ # same as above?
867+ _positions = {
868+ k : xr .DataArray (v , dims = [top_dim_names [k ]])
869+ for (k , v ) in _positions .items ()
870+ }
871+ _labels = {
872+ k : xr .DataArray (v , dims = [top_dim_names [k ]])
873+ for (k , v ) in _labels .items ()
874+ }
875+ # the top dim names have served their purpose, so clear them
876+ top_dim_names = {}
831877 y = xgather (result , _positions , _labels )
832878 if len (result .dims ) == 1 and len (y .dims ) == 1 :
833879 y = y .rename ({y .dims [0 ]: result .dims [0 ]})
@@ -844,19 +890,34 @@ def _getitem(
844890
845891 raise KeyError (item )
846892
847- def get_expr (self , expression , engine = "sharrow" , allow_native = True ):
893+ def get_expr (
894+ self ,
895+ expression ,
896+ engine = "sharrow" ,
897+ allow_native = True ,
898+ * ,
899+ dtype = "float32" ,
900+ with_coords : bool = True ,
901+ ):
848902 """
849903 Access or evaluate an expression.
850904
851905 Parameters
852906 ----------
853907 expression : str
854- engine : {'sharrow', 'numexpr'}
908+ engine : {'sharrow', 'numexpr', 'python' }
855909 The engine used to resolve expressions.
856910 allow_native : bool, default True
857911 If the expression is an array in a dataset of this tree, return
858912 that array directly. Set to false to force evaluation, which
859913 will also ensure proper broadcasting consistent with this data tree.
914+ dtype : str or dtype, default 'float32'
915+ The dtype to use when creating new arrays. This only applies when
916+ the expression is not returned as a native variable from the tree.
917+ with_coords : bool, default True
918+ Attach coordinates from the root node of the tree to the result.
919+ If the coordinates are not needed in the result, the process
920+ of attaching them can be skipped.
860921
861922 Returns
862923 -------
@@ -869,21 +930,185 @@ def get_expr(self, expression, engine="sharrow", allow_native=True):
869930 raise KeyError
870931 except (KeyError , IndexError ):
871932 if engine == "sharrow" :
933+ if dtype is None :
934+ dtype = "float32"
872935 result = (
873- self .setup_flow ({expression : expression })
936+ self .setup_flow ({expression : expression }, dtype = dtype )
874937 .load_dataarray ()
875938 .isel (expressions = 0 )
876939 )
877940 elif engine == "numexpr" :
878941 from xarray import DataArray
879942
880- result = DataArray (
881- pd .eval (expression , resolvers = [self ], engine = "numexpr" ),
882- )
943+ self ._eval_cache = {}
944+ try :
945+ result = DataArray (
946+ pd .eval (expression , resolvers = [self ], engine = "numexpr" ),
947+ ).astype (dtype )
948+ except NotImplementedError :
949+ result = DataArray (
950+ pd .eval (expression , resolvers = [self ], engine = "python" ),
951+ ).astype (dtype )
952+ else :
953+ # numexpr doesn't carry over the dimension names or coords
954+ result = result .rename (
955+ {result .dims [i ]: self .root_dims [i ] for i in range (result .ndim )}
956+ )
957+ if with_coords :
958+ result = result .assign_coords (self .root_dataset .coords )
959+ finally :
960+ del self ._eval_cache
961+ elif engine == "python" :
962+ from xarray import DataArray
963+
964+ self ._eval_cache = {}
965+ try :
966+ result = DataArray (
967+ pd .eval (expression , resolvers = [self ], engine = "python" ),
968+ ).astype (dtype )
969+ finally :
970+ del self ._eval_cache
883971 else :
884972 raise ValueError (f"unknown engine { engine } " ) from None
885973 return result
886974
975+ def eval (
976+ self ,
977+ expression : str ,
978+ engine : Literal [None , "numexpr" , "sharrow" , "python" ] = None ,
979+ * ,
980+ dtype : np .dtype | str | None = None ,
981+ name : str | None = None ,
982+ with_coords : bool = True ,
983+ ):
984+ """
985+ Evaluate an expression.
986+
987+ The resulting DataArray will have dimensions that match the root
988+ Dataset of this tree, and the content will be broadcast to those
989+ dimensions if necessary. The expression evaluated will be assigned
990+ as a scalar coordinate named 'expressions', to facilitate concatenation
991+ with other `eval` results if desired.
992+
993+ Parameters
994+ ----------
995+ expression : str
996+ engine : {None, 'numexpr', 'sharrow', 'python'}
997+ The engine used to resolve expressions. If None, the default is
998+ to try 'numexpr' first, then 'sharrow' if that fails.
999+ dtype : str or dtype, optional
1000+ The dtype to use for the result. If the engine is `sharrow` and
1001+ no value is given, this will default to `float32`, otherwise the
1002+ default is to use the dtype of the result of the expression.
1003+ name : str, optional
1004+ The name to give the resulting DataArray.
1005+
1006+ Returns
1007+ -------
1008+ DataArray
1009+ """
1010+ if not isinstance (expression , str ):
1011+ raise TypeError ("expression must be a string" )
1012+ if engine is None :
1013+ try :
1014+ result = self .get_expr (
1015+ expression ,
1016+ "numexpr" ,
1017+ allow_native = False ,
1018+ dtype = dtype ,
1019+ with_coords = with_coords ,
1020+ )
1021+ except Exception :
1022+ result = self .get_expr (
1023+ expression ,
1024+ "sharrow" ,
1025+ allow_native = False ,
1026+ dtype = dtype ,
1027+ with_coords = with_coords ,
1028+ )
1029+ else :
1030+ result = self .get_expr (
1031+ expression ,
1032+ engine ,
1033+ allow_native = False ,
1034+ dtype = dtype ,
1035+ with_coords = with_coords ,
1036+ )
1037+ if with_coords and "expressions" not in result .coords :
1038+ # add the expression as a scalar coordinate (with no dimension)
1039+ result = result .assign_coords (expressions = xr .DataArray (expression ))
1040+ if name is not None :
1041+ result .name = name
1042+ return result
1043+
1044+ def eval_many (
1045+ self ,
1046+ expressions : Sequence [str ] | Mapping [str , str ] | pd .Series ,
1047+ * ,
1048+ engine : Literal [None , "numexpr" , "sharrow" , "python" ] = None ,
1049+ dtype = None ,
1050+ result_type : Literal ["dataset" , "dataarray" ] = "dataset" ,
1051+ with_coords : bool = True ,
1052+ ) -> xr .Dataset | xr .DataArray :
1053+ """
1054+ Evaluate multiple expressions.
1055+
1056+ Parameters
1057+ ----------
1058+ expressions : Sequence[str] or Mapping[str,str] or pd.Series
1059+ The expressions to evaluate. If a sequence, the names of the
1060+ resulting DataArrays will be the same as the expressions. If a
1061+ mapping or Series, the keys or index will be used as the names.
1062+ engine : {None, 'numexpr', 'sharrow', 'python'}
1063+ The engine used to resolve expressions. If None, the default is to
1064+ try 'numexpr' first, then 'sharrow' if that fails.
1065+ dtype : str or dtype, optional
1066+ The dtype to use for the result. If the engine is `sharrow` and
1067+ no value is given, this will default to `float32`, otherwise the
1068+ default is to use the dtype of the result of the concatenation of
1069+ the expressions.
1070+ result_type : {'dataset', 'dataarray'}
1071+ Whether to return a Dataset (with a variable for each expression)
1072+ or a DataArray (with a dimension across all expressions).
1073+
1074+ Returns
1075+ -------
1076+ Dataset or DataArray
1077+ """
1078+ if result_type not in {"dataset" , "dataarray" }:
1079+ raise ValueError ("result_type must be one of ['dataset', 'dataarray']" )
1080+ if not isinstance (expressions , (Mapping , pd .Series )):
1081+ expressions = pd .Series (expressions , index = expressions )
1082+ if isinstance (expressions , Mapping ):
1083+ expressions = pd .Series (expressions )
1084+ if result_type == "dataset" :
1085+ arrays = {}
1086+ for k , v in expressions .items ():
1087+ a = self .eval (
1088+ v , engine = engine , dtype = dtype , name = k , with_coords = with_coords
1089+ )
1090+ if "expressions" in a .coords :
1091+ a = a .drop_vars ("expressions" )
1092+ arrays [k ] = a .assign_attrs (expression = v )
1093+ result = xr .Dataset (arrays )
1094+ else :
1095+ arrays = {}
1096+ for k , v in expressions .items ():
1097+ a = self .eval (
1098+ v , engine = engine , dtype = dtype , name = k , with_coords = with_coords
1099+ )
1100+ if "expressions" in a .coords :
1101+ a = a .drop_vars ("expressions" )
1102+ a = a .expand_dims ("expressions" , - 1 )
1103+ arrays [k ] = a
1104+ result = xr .concat (list (arrays .values ()), "expressions" )
1105+ if with_coords :
1106+ result = result .assign_coords (
1107+ expressions = expressions .index ,
1108+ source = xr .DataArray (expressions .values , dims = "expressions" ),
1109+ )
1110+ return result
1111+
8871112 @property
8881113 def subspaces (self ):
8891114 """Mapping[str,Dataset] : Direct access to node Dataset objects by name."""
@@ -1583,3 +1808,19 @@ def merged_dataset(self, columns=None, uniquify=False):
15831808 if coords :
15841809 result .assign_coords (coords )
15851810 return result
1811+
def __iter__(self):
    """Iterate over the contents of every dataset in the tree.

    Chains the temporary eval cache (when one is active) together with
    each subspace's dataset, yielding whatever iterating those mappings
    yields (i.e. their keys / variable names).
    """
    from itertools import chain

    sources = []
    if hasattr(self, "_eval_cache"):
        sources.append(self._eval_cache)
    sources.extend(v for _k, v in self.subspaces_iter())
    return chain.from_iterable(sources)
1821+
def __setitem__(self, key, value):
    """Bind *key* to *value* in the temporary eval cache.

    Assignment is only supported while an `_eval_cache` is active (i.e.
    during expression evaluation); otherwise item assignment raises.
    """
    if not hasattr(self, "_eval_cache"):
        raise NotImplementedError("setitem not supported")
    self._eval_cache[key] = value
0 commit comments