22import logging
33import warnings
44from collections .abc import Mapping , Sequence
5+ from numbers import Number
56from typing import Literal
67
78import networkx as nx
1011import xarray as xr
1112
1213from .dataset import Dataset , construct
13- from .tree_branch import DataTreeBranch
14+ from .tree_branch import CachedTree , DataTreeBranch
1415
1516try :
1617 from dask .array import Array as dask_array_type
@@ -898,15 +899,18 @@ def get_expr(
898899 * ,
899900 dtype = "float32" ,
900901 with_coords : bool = True ,
902+ parser : Literal ["pandas" , "python" ] = "pandas" ,
901903 ):
902904 """
903905 Access or evaluate an expression.
904906
905907 Parameters
906908 ----------
907909 expression : str
908- engine : {'sharrow', 'numexpr', 'python'}
909- The engine used to resolve expressions.
910+ engine : {'sharrow', 'numexpr', 'python', 'pandas-numexpr'}
911+ The engine used to resolve expressions. The numexpr engine uses
912+ that library directly, while the pandas-numexpr engine uses the
913+ pandas `eval` method with the numexpr engine.
910914 allow_native : bool, default True
911915 If the expression is an array in a dataset of this tree, return
912916 that array directly. Set to false to force evaluation, which
@@ -918,11 +922,19 @@ def get_expr(
918922 Attach coordinates from the root node of the tree to the result.
919923 If the coordinates are not needed in the result, the process
920924 of attaching them can be skipped.
925+ parser : {'pandas', 'python'}
926+ The parser to use when evaluating the expression. This argument
927+ only applies to pandas-based engines ('python' and 'pandas-numexpr').
928+ It is ignored when using the 'sharrow' or 'numexpr' engines.
921929
922930 Returns
923931 -------
924932 DataArray
925933 """
934+ if np .issubdtype (dtype , np .number ) and isinstance (dtype , type ):
935+ dtype = dtype .__name__
936+ elif dtype is bool :
937+ dtype = "bool"
926938 try :
927939 if allow_native :
928940 result = self [expression ]
@@ -938,16 +950,49 @@ def get_expr(
938950 .isel (expressions = 0 )
939951 )
940952 elif engine == "numexpr" :
953+ import numexpr as ne
954+ from xarray import DataArray
955+
956+ try :
957+ result = DataArray (
958+ ne .evaluate (expression , local_dict = CachedTree (self )),
959+ )
960+ except Exception :
961+ if dtype is None :
962+ dtype = "float32"
963+ result = (
964+ self .setup_flow ({expression : expression }, dtype = dtype )
965+ .load_dataarray ()
966+ .isel (expressions = 0 )
967+ )
968+ else :
969+ if dtype is not None :
970+ result = result .astype (dtype )
971+ # numexpr doesn't carry over the dimension names or coords
972+ result = result .rename (
973+ {result .dims [i ]: self .root_dims [i ] for i in range (result .ndim )}
974+ )
975+ if with_coords :
976+ result = result .assign_coords (self .root_dataset .coords )
977+
978+ elif engine == "pandas-numexpr" :
941979 from xarray import DataArray
942980
943981 self ._eval_cache = {}
944982 try :
945983 result = DataArray (
946- pd .eval (expression , resolvers = [self ], engine = "numexpr" ),
984+ pd .eval (
985+ expression ,
986+ resolvers = [self ],
987+ engine = "numexpr" ,
988+ parser = parser ,
989+ ),
947990 ).astype (dtype )
948991 except NotImplementedError :
949992 result = DataArray (
950- pd .eval (expression , resolvers = [self ], engine = "python" ),
993+ pd .eval (
994+ expression , resolvers = [self ], engine = "python" , parser = parser
995+ ),
951996 ).astype (dtype )
952997 else :
953998 # numexpr doesn't carry over the dimension names or coords
@@ -964,7 +1009,9 @@ def get_expr(
9641009 self ._eval_cache = {}
9651010 try :
9661011 result = DataArray (
967- pd .eval (expression , resolvers = [self ], engine = "python" ),
1012+ pd .eval (
1013+ expression , resolvers = [self ], engine = "python" , parser = parser
1014+ ),
9681015 ).astype (dtype )
9691016 finally :
9701017 del self ._eval_cache
@@ -974,7 +1021,7 @@ def get_expr(
9741021
9751022 def eval (
9761023 self ,
977- expression : str ,
1024+ expression : str | Number ,
9781025 engine : Literal [None , "numexpr" , "sharrow" , "python" ] = None ,
9791026 * ,
9801027 dtype : np .dtype | str | None = None ,
@@ -992,7 +1039,7 @@ def eval(
9921039
9931040 Parameters
9941041 ----------
995- expression : str
1042+ expression : str | Number
9961043 engine : {None, 'numexpr', 'sharrow', 'python'}
9971044 The engine used to resolve expressions. If None, the default is
9981045 to try 'numexpr' first, then 'sharrow' if that fails.
@@ -1007,33 +1054,45 @@ def eval(
10071054 -------
10081055 DataArray
10091056 """
1010- if not isinstance (expression , str ):
1011- raise TypeError ("expression must be a string" )
1012- if engine is None :
1013- try :
1014- result = self .get_expr (
1015- expression ,
1016- "numexpr" ,
1017- allow_native = False ,
1018- dtype = dtype ,
1019- with_coords = with_coords ,
1057+ # when passing in a numeric value or boolean, simply broadcast it to the root dims
1058+ if isinstance (expression , bool ):
1059+ expression = int (expression )
1060+ if isinstance (expression , Number ):
1061+ this_shape = [self .root_dataset .sizes .get (i ) for i in self .root_dims ]
1062+ result = xr .DataArray (
1063+ np .broadcast_to (expression , this_shape ), dims = self .root_dims
1064+ )
1065+ expression = str (expression )
1066+ else :
1067+ if not isinstance (expression , str ):
1068+ raise TypeError (
1069+ f"expression must be a string, not a { type (expression )} "
10201070 )
1021- except Exception :
1071+ if engine is None :
1072+ try :
1073+ result = self .get_expr (
1074+ expression ,
1075+ "numexpr" ,
1076+ allow_native = False ,
1077+ dtype = dtype ,
1078+ with_coords = with_coords ,
1079+ )
1080+ except Exception :
1081+ result = self .get_expr (
1082+ expression ,
1083+ "sharrow" ,
1084+ allow_native = False ,
1085+ dtype = dtype ,
1086+ with_coords = with_coords ,
1087+ )
1088+ else :
10221089 result = self .get_expr (
10231090 expression ,
1024- "sharrow" ,
1091+ engine ,
10251092 allow_native = False ,
10261093 dtype = dtype ,
10271094 with_coords = with_coords ,
10281095 )
1029- else :
1030- result = self .get_expr (
1031- expression ,
1032- engine ,
1033- allow_native = False ,
1034- dtype = dtype ,
1035- with_coords = with_coords ,
1036- )
10371096 if with_coords and "expressions" not in result .coords :
10381097 # add the expression as a scalar coordinate (with no dimension)
10391098 result = result .assign_coords (expressions = xr .DataArray (expression ))
@@ -1081,6 +1140,8 @@ def eval_many(
10811140 expressions = pd .Series (expressions , index = expressions )
10821141 if isinstance (expressions , Mapping ):
10831142 expressions = pd .Series (expressions )
1143+ if len (expressions ) == 0 :
1144+ raise ValueError ("no expressions provided" )
10841145 if result_type == "dataset" :
10851146 arrays = {}
10861147 for k , v in expressions .items ():
0 commit comments