Skip to content

Commit 652da47

Browse files
authored
More options for eval engine (#68)
* dtype arg for get_expr
* improve get_expr
* add DataTree.eval
* eval_many
* remove rogue print
* with_coords
* enable dotted names with numexpr and python engines
* fix dimension order with numexpr
* default to numexpr usually
1 parent 113ba37 commit 652da47

File tree

4 files changed

+1994
-11
lines changed

4 files changed

+1994
-11
lines changed

sharrow/relationships.py

Lines changed: 252 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import ast
22
import logging
33
import warnings
4+
from collections.abc import Mapping, Sequence
5+
from typing import Literal
46

57
import networkx as nx
68
import numpy as np
79
import pandas as pd
810
import xarray as xr
911

1012
from .dataset import Dataset, construct
13+
from .tree_branch import DataTreeBranch
1114

1215
try:
1316
from dask.array import Array as dask_array_type
@@ -69,7 +72,10 @@ def _ixname():
6972
return f"index{inum}"
7073

7174
for k, v in idxs.items():
72-
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
75+
if isinstance(v, xr.DataArray):
76+
loaders[k] = v
77+
else:
78+
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
7379
if _names:
7480
ds = source[_names]
7581
else:
@@ -91,7 +97,10 @@ def _ixname():
9197
return f"index{inum}"
9298

9399
for k, v in idxs.items():
94-
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
100+
if isinstance(v, xr.DataArray):
101+
loaders[k] = v
102+
else:
103+
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
95104
if _names:
96105
ds = source[_names]
97106
else:
@@ -575,8 +584,6 @@ def add_dataset(self, name, dataset, relationships=(), as_root=False):
575584
self.digitize_relationships(inplace=True)
576585

577586
def add_items(self, items):
578-
from collections.abc import Mapping, Sequence
579-
580587
if isinstance(items, Sequence):
581588
for i in items:
582589
self.add_items(i)
@@ -621,7 +628,15 @@ def _get_relationship(self, edge):
621628
)
622629

623630
def __getitem__(self, item):
    """Look up *item*, consulting the temporary eval cache first.

    Resolution order: (1) the ``_eval_cache`` that exists only while an
    expression is being evaluated, (2) ``self.get``, and (3) a
    ``DataTreeBranch`` wrapper when *item* names a graph node carrying a
    dataset.  The original ``KeyError`` propagates when none apply.
    """
    cache = getattr(self, "_eval_cache", None)
    if cache is not None and item in cache:
        return cache[item]
    try:
        return self.get(item)
    except KeyError as err:
        node_dataset = self._graph.nodes.get(item, {}).get("dataset", None)
        if node_dataset is None:
            raise err
        return DataTreeBranch(self, item)
625640

626641
def get(self, item, default=None, broadcast=True, coords=True):
627642
"""
@@ -687,6 +702,11 @@ def get(self, item, default=None, broadcast=True, coords=True):
687702
add_coords[i] = base_dataset.coords[i]
688703
if add_coords:
689704
result = result.assign_coords(add_coords)
705+
if broadcast:
706+
if self.dim_order is None:
707+
result = result.transpose(*self.root_dims)
708+
else:
709+
result = result.transpose(*self.dim_order)
690710
return result
691711

692712
def finditem(self, item, maybe_in=None):
@@ -828,6 +848,32 @@ def _getitem(
828848
_positions[r.child_name] = _idx
829849
if top_dim_name is not None:
830850
top_dim_names[r.child_name] = top_dim_name
851+
if len(top_dim_names) > 1:
852+
if len(set(top_dim_names.values())) == 1:
853+
# capture the situation where all top dims are the same
854+
_positions = {
855+
k: xr.DataArray(v, dims=[top_dim_names[k]])
856+
for (k, v) in _positions.items()
857+
}
858+
_labels = {
859+
k: xr.DataArray(v, dims=[top_dim_names[k]])
860+
for (k, v) in _labels.items()
861+
}
862+
# the top dim names have served their purpose, so clear them
863+
top_dim_names = {}
864+
elif len(set(top_dim_names.values())) < len(top_dim_names):
865+
# capture the situation where some but not all top dims are the same
866+
# same as above?
867+
_positions = {
868+
k: xr.DataArray(v, dims=[top_dim_names[k]])
869+
for (k, v) in _positions.items()
870+
}
871+
_labels = {
872+
k: xr.DataArray(v, dims=[top_dim_names[k]])
873+
for (k, v) in _labels.items()
874+
}
875+
# the top dim names have served their purpose, so clear them
876+
top_dim_names = {}
831877
y = xgather(result, _positions, _labels)
832878
if len(result.dims) == 1 and len(y.dims) == 1:
833879
y = y.rename({y.dims[0]: result.dims[0]})
@@ -844,19 +890,34 @@ def _getitem(
844890

845891
raise KeyError(item)
846892

847-
def get_expr(self, expression, engine="sharrow", allow_native=True):
893+
def get_expr(
894+
self,
895+
expression,
896+
engine="sharrow",
897+
allow_native=True,
898+
*,
899+
dtype="float32",
900+
with_coords: bool = True,
901+
):
848902
"""
849903
Access or evaluate an expression.
850904
851905
Parameters
852906
----------
853907
expression : str
854-
engine : {'sharrow', 'numexpr'}
908+
engine : {'sharrow', 'numexpr', 'python'}
855909
The engine used to resolve expressions.
856910
allow_native : bool, default True
857911
If the expression is an array in a dataset of this tree, return
858912
that array directly. Set to false to force evaluation, which
859913
will also ensure proper broadcasting consistent with this data tree.
914+
dtype : str or dtype, default 'float32'
915+
The dtype to use when creating new arrays. This only applies when
916+
the expression is not returned as a native variable from the tree.
917+
with_coords : bool, default True
918+
Attach coordinates from the root node of the tree to the result.
919+
If the coordinates are not needed in the result, the process
920+
of attaching them can be skipped.
860921
861922
Returns
862923
-------
@@ -869,21 +930,185 @@ def get_expr(self, expression, engine="sharrow", allow_native=True):
869930
raise KeyError
870931
except (KeyError, IndexError):
871932
if engine == "sharrow":
933+
if dtype is None:
934+
dtype = "float32"
872935
result = (
873-
self.setup_flow({expression: expression})
936+
self.setup_flow({expression: expression}, dtype=dtype)
874937
.load_dataarray()
875938
.isel(expressions=0)
876939
)
877940
elif engine == "numexpr":
878941
from xarray import DataArray
879942

880-
result = DataArray(
881-
pd.eval(expression, resolvers=[self], engine="numexpr"),
882-
)
943+
self._eval_cache = {}
944+
try:
945+
result = DataArray(
946+
pd.eval(expression, resolvers=[self], engine="numexpr"),
947+
).astype(dtype)
948+
except NotImplementedError:
949+
result = DataArray(
950+
pd.eval(expression, resolvers=[self], engine="python"),
951+
).astype(dtype)
952+
else:
953+
# numexpr doesn't carry over the dimension names or coords
954+
result = result.rename(
955+
{result.dims[i]: self.root_dims[i] for i in range(result.ndim)}
956+
)
957+
if with_coords:
958+
result = result.assign_coords(self.root_dataset.coords)
959+
finally:
960+
del self._eval_cache
961+
elif engine == "python":
962+
from xarray import DataArray
963+
964+
self._eval_cache = {}
965+
try:
966+
result = DataArray(
967+
pd.eval(expression, resolvers=[self], engine="python"),
968+
).astype(dtype)
969+
finally:
970+
del self._eval_cache
883971
else:
884972
raise ValueError(f"unknown engine {engine}") from None
885973
return result
886974

975+
def eval(
    self,
    expression: str,
    engine: Literal[None, "numexpr", "sharrow", "python"] = None,
    *,
    dtype: np.dtype | str | None = None,
    name: str | None = None,
    with_coords: bool = True,
):
    """
    Evaluate an expression.

    The resulting DataArray will have dimensions that match the root
    Dataset of this tree, and the content will be broadcast to those
    dimensions if necessary. The expression evaluated will be assigned
    as a scalar coordinate named 'expressions', to facilitate concatenation
    with other `eval` results if desired.

    Parameters
    ----------
    expression : str
    engine : {None, 'numexpr', 'sharrow', 'python'}
        The engine used to resolve expressions. If None, the default is
        to try 'numexpr' first, then 'sharrow' if that fails.
    dtype : str or dtype, optional
        The dtype to use for the result. If the engine is `sharrow` and
        no value is given, this will default to `float32`, otherwise the
        default is to use the dtype of the result of the expression.
    name : str, optional
        The name to give the resulting DataArray.
    with_coords : bool, default True
        Attach coordinates from the root node of the tree to the result,
        including the scalar 'expressions' coordinate.  Skipping this can
        save time when the coordinates are not needed in the result.

    Returns
    -------
    DataArray
    """
    if not isinstance(expression, str):
        raise TypeError("expression must be a string")
    if engine is None:
        # Prefer the fast numexpr engine; fall back to sharrow's compiled
        # flow when numexpr cannot evaluate the expression.
        try:
            result = self.get_expr(
                expression,
                "numexpr",
                allow_native=False,
                dtype=dtype,
                with_coords=with_coords,
            )
        except Exception:
            result = self.get_expr(
                expression,
                "sharrow",
                allow_native=False,
                dtype=dtype,
                with_coords=with_coords,
            )
    else:
        result = self.get_expr(
            expression,
            engine,
            allow_native=False,
            dtype=dtype,
            with_coords=with_coords,
        )
    if with_coords and "expressions" not in result.coords:
        # add the expression as a scalar coordinate (with no dimension)
        result = result.assign_coords(expressions=xr.DataArray(expression))
    if name is not None:
        result.name = name
    return result
1043+
1044+
def eval_many(
1045+
self,
1046+
expressions: Sequence[str] | Mapping[str, str] | pd.Series,
1047+
*,
1048+
engine: Literal[None, "numexpr", "sharrow", "python"] = None,
1049+
dtype=None,
1050+
result_type: Literal["dataset", "dataarray"] = "dataset",
1051+
with_coords: bool = True,
1052+
) -> xr.Dataset | xr.DataArray:
1053+
"""
1054+
Evaluate multiple expressions.
1055+
1056+
Parameters
1057+
----------
1058+
expressions : Sequence[str] or Mapping[str,str] or pd.Series
1059+
The expressions to evaluate. If a sequence, the names of the
1060+
resulting DataArrays will be the same as the expressions. If a
1061+
mapping or Series, the keys or index will be used as the names.
1062+
engine : {None, 'numexpr', 'sharrow', 'python'}
1063+
The engine used to resolve expressions. If None, the default is to
1064+
try 'numexpr' first, then 'sharrow' if that fails.
1065+
dtype : str or dtype, optional
1066+
The dtype to use for the result. If the engine is `sharrow` and
1067+
no value is given, this will default to `float32`, otherwise the
1068+
default is to use the dtype of the result of the concatenation of
1069+
the expressions.
1070+
result_type : {'dataset', 'dataarray'}
1071+
Whether to return a Dataset (with a variable for each expression)
1072+
or a DataArray (with a dimension across all expressions).
1073+
1074+
Returns
1075+
-------
1076+
Dataset or DataArray
1077+
"""
1078+
if result_type not in {"dataset", "dataarray"}:
1079+
raise ValueError("result_type must be one of ['dataset', 'dataarray']")
1080+
if not isinstance(expressions, (Mapping, pd.Series)):
1081+
expressions = pd.Series(expressions, index=expressions)
1082+
if isinstance(expressions, Mapping):
1083+
expressions = pd.Series(expressions)
1084+
if result_type == "dataset":
1085+
arrays = {}
1086+
for k, v in expressions.items():
1087+
a = self.eval(
1088+
v, engine=engine, dtype=dtype, name=k, with_coords=with_coords
1089+
)
1090+
if "expressions" in a.coords:
1091+
a = a.drop_vars("expressions")
1092+
arrays[k] = a.assign_attrs(expression=v)
1093+
result = xr.Dataset(arrays)
1094+
else:
1095+
arrays = {}
1096+
for k, v in expressions.items():
1097+
a = self.eval(
1098+
v, engine=engine, dtype=dtype, name=k, with_coords=with_coords
1099+
)
1100+
if "expressions" in a.coords:
1101+
a = a.drop_vars("expressions")
1102+
a = a.expand_dims("expressions", -1)
1103+
arrays[k] = a
1104+
result = xr.concat(list(arrays.values()), "expressions")
1105+
if with_coords:
1106+
result = result.assign_coords(
1107+
expressions=expressions.index,
1108+
source=xr.DataArray(expressions.values, dims="expressions"),
1109+
)
1110+
return result
1111+
8871112
@property
8881113
def subspaces(self):
8891114
"""Mapping[str,Dataset] : Direct access to node Dataset objects by name."""
@@ -1583,3 +1808,19 @@ def merged_dataset(self, columns=None, uniquify=False):
15831808
if coords:
15841809
result.assign_coords(coords)
15851810
return result
1811+
1812+
def __iter__(self):
    """Iterate over all the datasets.

    Yields the keys of the temporary ``_eval_cache`` first (when an
    expression evaluation is in progress), followed by the keys of each
    node's dataset from ``subspaces_iter``.
    """
    from itertools import chain

    head = (self._eval_cache,) if hasattr(self, "_eval_cache") else ()
    return chain(*head, *(dataset for _, dataset in self.subspaces_iter()))
1821+
1822+
def __setitem__(self, key, value):
    """Store *value* in the temporary eval cache.

    Item assignment is only supported while ``_eval_cache`` exists, i.e.
    during an expression evaluation; otherwise the tree is read-only
    through this interface.
    """
    if not hasattr(self, "_eval_cache"):
        raise NotImplementedError("setitem not supported")
    self._eval_cache[key] = value

0 commit comments

Comments
 (0)