Skip to content

Commit 52c594b

Browse files
authored
Raw numexpr engine (#70)
* ValueError on no expressions
* allow number as expression in eval
* broadcast expressions
* raw numexpr engine
* ruffen
* numexpr dims
* fix for when passing dtypes instead of dtype names
* ruffen
* fix subdtype
* try again
* explicitly include package data
* change to use importlib.resources
* ruffen
* ruffen
* hide test omx
* don't test in editable mode
1 parent 652da47 commit 52c594b

File tree

9 files changed

+211 additions and -56 deletions (lines changed)

9 files changed

+211 additions and -56 deletions (lines changed)

.github/workflows/run-tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ jobs:
9292
auto-update-conda: false
9393
- name: Install sharrow
9494
run: |
95-
python -m pip install -e .
95+
python -m pip install .
9696
- name: Conda checkup
9797
run: |
9898
conda info -a
@@ -137,7 +137,7 @@ jobs:
137137
conda install jupyter-book ruamel.yaml sphinx-autosummary-accessors -c conda-forge
138138
- name: Install sharrow
139139
run: |
140-
python -m pip install --no-deps -e .
140+
python -m pip install --no-deps .
141141
- name: Conda checkup
142142
run: |
143143
conda info -a

envs/development.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ dependencies:
1010
- filelock
1111
- ruff
1212
- jupyter
13-
- larch>=5.7.1
1413
- nbmake
1514
- networkx
1615
- notebook
@@ -29,4 +28,5 @@ dependencies:
2928
- zarr
3029

3130
- pip:
31+
- larch6
3232
- -e ..

envs/testing.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies:
2222
- pytest-xdist
2323
- nbmake
2424
- openmatrix
25+
- h5py
2526
- zarr
2627
- pip:
2728
- larch6

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ Repository = "https://github.com/activitysim/sharrow"
3737

3838
[tool.setuptools]
3939
packages = ["sharrow", "sharrow.utils"]
40+
include-package-data = true
41+
42+
[tool.setuptools.package-data]
43+
sharrow = ["*"]
4044

4145
[tool.setuptools_scm]
4246
fallback_version = "1999"

sharrow/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def from_table(
269269
result = xr.Dataset()
270270
if isinstance(index, pd.MultiIndex):
271271
dims = tuple(
272-
name if name is not None else "level_%i" % n
272+
name if name is not None else f"level_{n}"
273273
for n, name in enumerate(index.names)
274274
)
275275
for dim, lev in zip(dims, index.levels):

sharrow/example_data.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from importlib.resources import as_file, files
23

34
import numpy as np
45
import pandas as pd
@@ -9,17 +10,14 @@ def get_skims_filename() -> str:
910
return os.path.join(os.path.dirname(__file__), "example_data", "skims.omx")
1011

1112

12-
def get_skims():
13+
def get_skims_omx():
1314
import openmatrix
1415

1516
from . import dataset
1617

17-
zfilename = os.path.join(os.path.dirname(__file__), "example_data", "skims.zarr")
18-
if os.path.exists(zfilename):
19-
skims = dataset.from_zarr(zfilename, consolidated=False)
20-
else:
21-
filename = os.path.join(os.path.dirname(__file__), "example_data", "skims.omx")
22-
with openmatrix.open_file(filename) as f:
18+
with as_file(files("sharrow").joinpath("example_data/skims.omx")) as filename:
19+
skims = None
20+
with openmatrix.open_file(str(filename)) as f:
2321
skims = dataset.from_omx_3d(
2422
f,
2523
index_names=("otaz", "dtaz", "time_period"),
@@ -28,39 +26,56 @@ def get_skims():
2826
time_period_sep="__",
2927
max_float_precision=32,
3028
).compute()
31-
skims.to_zarr(zfilename)
29+
return skims
30+
31+
32+
def get_skims_zarr():
33+
from . import dataset
34+
35+
f = files("sharrow").joinpath("example_data/skims.zarr")
36+
with as_file(f) as zfile:
37+
if zfile.exists():
38+
skims = dataset.from_zarr(zfile, consolidated=False)
39+
else:
40+
skims = None
41+
return skims
42+
43+
44+
def get_skims():
45+
from . import dataset
46+
47+
f = files("sharrow").joinpath("example_data/skims.zarr")
48+
with as_file(f) as zfile:
49+
if zfile.exists():
50+
skims = dataset.from_zarr(zfile, consolidated=False)
51+
else:
52+
skims = get_skims_omx()
3253
return skims
3354

3455

3556
def get_households():
36-
filename = os.path.join(
37-
os.path.dirname(__file__), "example_data", "households.csv.gz"
38-
)
39-
return pd.read_csv(filename, index_col="HHID")
57+
with as_file(files("sharrow").joinpath("example_data/households.csv.gz")) as f:
58+
return pd.read_csv(f, index_col="HHID")
4059

4160

4261
def get_persons():
43-
filename = os.path.join(os.path.dirname(__file__), "example_data", "persons.csv.gz")
44-
return pd.read_csv(filename, index_col="PERID")
62+
with as_file(files("sharrow").joinpath("example_data/persons.csv.gz")) as f:
63+
return pd.read_csv(f, index_col="PERID")
4564

4665

4766
def get_land_use():
48-
filename = os.path.join(
49-
os.path.dirname(__file__), "example_data", "land_use.csv.gz"
50-
)
51-
return pd.read_csv(filename, index_col="TAZ")
67+
with as_file(files("sharrow").joinpath("example_data/land_use.csv.gz")) as f:
68+
return pd.read_csv(f, index_col="TAZ")
5269

5370

5471
def get_maz_to_taz():
55-
filename = os.path.join(os.path.dirname(__file__), "example_data", "maz_to_taz.csv")
56-
return pd.read_csv(filename, index_col="MAZ")
72+
with as_file(files("sharrow").joinpath("example_data/maz_to_taz.csv")) as f:
73+
return pd.read_csv(f, index_col="MAZ")
5774

5875

5976
def get_maz_to_maz_walk():
60-
filename = os.path.join(
61-
os.path.dirname(__file__), "example_data", "maz_to_maz_walk.csv"
62-
)
63-
return pd.read_csv(filename)
77+
with as_file(files("sharrow").joinpath("example_data/maz_to_maz_walk.csv")) as f:
78+
return pd.read_csv(f)
6479

6580

6681
def get_data():

sharrow/relationships.py

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import warnings
44
from collections.abc import Mapping, Sequence
5+
from numbers import Number
56
from typing import Literal
67

78
import networkx as nx
@@ -10,7 +11,7 @@
1011
import xarray as xr
1112

1213
from .dataset import Dataset, construct
13-
from .tree_branch import DataTreeBranch
14+
from .tree_branch import CachedTree, DataTreeBranch
1415

1516
try:
1617
from dask.array import Array as dask_array_type
@@ -898,15 +899,18 @@ def get_expr(
898899
*,
899900
dtype="float32",
900901
with_coords: bool = True,
902+
parser: Literal["pandas", "python"] = "pandas",
901903
):
902904
"""
903905
Access or evaluate an expression.
904906
905907
Parameters
906908
----------
907909
expression : str
908-
engine : {'sharrow', 'numexpr', 'python'}
909-
The engine used to resolve expressions.
910+
engine : {'sharrow', 'numexpr', 'python', 'pandas-numexpr'}
911+
The engine used to resolve expressions. The numexpr engine uses
912+
that library directly, while the pandas-numexpr engine uses the
913+
pandas `eval` method with the numexpr engine.
910914
allow_native : bool, default True
911915
If the expression is an array in a dataset of this tree, return
912916
that array directly. Set to false to force evaluation, which
@@ -918,11 +922,19 @@ def get_expr(
918922
Attach coordinates from the root node of the tree to the result.
919923
If the coordinates are not needed in the result, the process
920924
of attaching them can be skipped.
925+
parser : {'pandas', 'python'}
926+
The parser to use when evaluating the expression. This argument
927+
only applies to pandas-based engines ('python' and 'pandas-numexpr').
928+
It is ignored when using the 'sharrow' or 'numexpr' engines.
921929
922930
Returns
923931
-------
924932
DataArray
925933
"""
934+
if np.issubdtype(dtype, np.number) and isinstance(dtype, type):
935+
dtype = dtype.__name__
936+
elif dtype is bool:
937+
dtype = "bool"
926938
try:
927939
if allow_native:
928940
result = self[expression]
@@ -938,16 +950,49 @@ def get_expr(
938950
.isel(expressions=0)
939951
)
940952
elif engine == "numexpr":
953+
import numexpr as ne
954+
from xarray import DataArray
955+
956+
try:
957+
result = DataArray(
958+
ne.evaluate(expression, local_dict=CachedTree(self)),
959+
)
960+
except Exception:
961+
if dtype is None:
962+
dtype = "float32"
963+
result = (
964+
self.setup_flow({expression: expression}, dtype=dtype)
965+
.load_dataarray()
966+
.isel(expressions=0)
967+
)
968+
else:
969+
if dtype is not None:
970+
result = result.astype(dtype)
971+
# numexpr doesn't carry over the dimension names or coords
972+
result = result.rename(
973+
{result.dims[i]: self.root_dims[i] for i in range(result.ndim)}
974+
)
975+
if with_coords:
976+
result = result.assign_coords(self.root_dataset.coords)
977+
978+
elif engine == "pandas-numexpr":
941979
from xarray import DataArray
942980

943981
self._eval_cache = {}
944982
try:
945983
result = DataArray(
946-
pd.eval(expression, resolvers=[self], engine="numexpr"),
984+
pd.eval(
985+
expression,
986+
resolvers=[self],
987+
engine="numexpr",
988+
parser=parser,
989+
),
947990
).astype(dtype)
948991
except NotImplementedError:
949992
result = DataArray(
950-
pd.eval(expression, resolvers=[self], engine="python"),
993+
pd.eval(
994+
expression, resolvers=[self], engine="python", parser=parser
995+
),
951996
).astype(dtype)
952997
else:
953998
# numexpr doesn't carry over the dimension names or coords
@@ -964,7 +1009,9 @@ def get_expr(
9641009
self._eval_cache = {}
9651010
try:
9661011
result = DataArray(
967-
pd.eval(expression, resolvers=[self], engine="python"),
1012+
pd.eval(
1013+
expression, resolvers=[self], engine="python", parser=parser
1014+
),
9681015
).astype(dtype)
9691016
finally:
9701017
del self._eval_cache
@@ -974,7 +1021,7 @@ def get_expr(
9741021

9751022
def eval(
9761023
self,
977-
expression: str,
1024+
expression: str | Number,
9781025
engine: Literal[None, "numexpr", "sharrow", "python"] = None,
9791026
*,
9801027
dtype: np.dtype | str | None = None,
@@ -992,7 +1039,7 @@ def eval(
9921039
9931040
Parameters
9941041
----------
995-
expression : str
1042+
expression : str | Number
9961043
engine : {None, 'numexpr', 'sharrow', 'python'}
9971044
The engine used to resolve expressions. If None, the default is
9981045
to try 'numexpr' first, then 'sharrow' if that fails.
@@ -1007,33 +1054,45 @@ def eval(
10071054
-------
10081055
DataArray
10091056
"""
1010-
if not isinstance(expression, str):
1011-
raise TypeError("expression must be a string")
1012-
if engine is None:
1013-
try:
1014-
result = self.get_expr(
1015-
expression,
1016-
"numexpr",
1017-
allow_native=False,
1018-
dtype=dtype,
1019-
with_coords=with_coords,
1057+
# when passing in a numeric value or boolean, simply broadcast it to the root dims
1058+
if isinstance(expression, bool):
1059+
expression = int(expression)
1060+
if isinstance(expression, Number):
1061+
this_shape = [self.root_dataset.sizes.get(i) for i in self.root_dims]
1062+
result = xr.DataArray(
1063+
np.broadcast_to(expression, this_shape), dims=self.root_dims
1064+
)
1065+
expression = str(expression)
1066+
else:
1067+
if not isinstance(expression, str):
1068+
raise TypeError(
1069+
f"expression must be a string, not a {type(expression)}"
10201070
)
1021-
except Exception:
1071+
if engine is None:
1072+
try:
1073+
result = self.get_expr(
1074+
expression,
1075+
"numexpr",
1076+
allow_native=False,
1077+
dtype=dtype,
1078+
with_coords=with_coords,
1079+
)
1080+
except Exception:
1081+
result = self.get_expr(
1082+
expression,
1083+
"sharrow",
1084+
allow_native=False,
1085+
dtype=dtype,
1086+
with_coords=with_coords,
1087+
)
1088+
else:
10221089
result = self.get_expr(
10231090
expression,
1024-
"sharrow",
1091+
engine,
10251092
allow_native=False,
10261093
dtype=dtype,
10271094
with_coords=with_coords,
10281095
)
1029-
else:
1030-
result = self.get_expr(
1031-
expression,
1032-
engine,
1033-
allow_native=False,
1034-
dtype=dtype,
1035-
with_coords=with_coords,
1036-
)
10371096
if with_coords and "expressions" not in result.coords:
10381097
# add the expression as a scalar coordinate (with no dimension)
10391098
result = result.assign_coords(expressions=xr.DataArray(expression))
@@ -1081,6 +1140,8 @@ def eval_many(
10811140
expressions = pd.Series(expressions, index=expressions)
10821141
if isinstance(expressions, Mapping):
10831142
expressions = pd.Series(expressions)
1143+
if len(expressions) == 0:
1144+
raise ValueError("no expressions provided")
10841145
if result_type == "dataset":
10851146
arrays = {}
10861147
for k, v in expressions.items():

0 commit comments

Comments
 (0)