Skip to content

Commit f1526ac

Browse files
max-sixty and claude authored
Dataset.eval works with >2 dims (#11064)
* Replace pandas.eval with native implementation

  This commit removes the dependency on pandas.eval() and implements a native expression evaluator in Dataset.eval() using Python's ast module. The new implementation provides better support for multi-dimensional arrays and maintains backward compatibility with deprecated operators through automatic transformation.

  Key changes:
  - Remove pd.eval() call and replace with custom _eval_expression() method
  - Add _LogicalOperatorTransformer to convert deprecated operators (and/or/not) to bitwise operators (&/|/~) that work element-wise on arrays
  - Implement automatic transformation of chained comparisons to explicit bitwise AND operations
  - Add security validation to block lambda expressions and private attributes
  - Emit FutureWarning for deprecated constructs (logical operators, chained comparisons, parser= argument)
  - Support assignment statements (target = expression) in eval()
  - Make data variables and coordinates take priority in namespace resolution
  - Provide safe builtins (abs, min, max, round, len, sum, pow, any, all, type constructors, iteration helpers) while blocking __import__, open, etc.
  - Add comprehensive test coverage including edge cases, error messages, dask compatibility, and security validation

* Fix mypy errors in eval tests

  - Use pd.isna(ds["a"].values) instead of pd.isna(ds["a"]) since pandas type stubs don't have overloads for DataArray
  - Use abs() instead of np.abs() to get DataArray return type

  Co-authored-by: Claude <noreply@anthropic.com>

* Remove security framing, frame restrictions as pd.eval() compatibility

  The lambda and dunder restrictions emulate pd.eval() behavior rather than providing security guarantees. Pandas explicitly doesn't claim these as security measures.

  Co-authored-by: Claude <noreply@anthropic.com>

* Move eval implementation to dedicated module

  Extract AST-based expression evaluation code to xarray/core/eval.py:
  - EVAL_BUILTINS dict
  - LogicalOperatorTransformer class
  - validate_expression function

  This addresses the review feedback to keep the Dataset class focused.

  Co-authored-by: Claude <noreply@anthropic.com>

* Move eval tests to dedicated test_eval.py module

  Extract eval tests from test_dataset.py to test_eval.py:
  - 35 tests covering basic functionality, error messages, edge cases, and dask
  - Mirrors the implementation structure (core/eval.py <-> tests/test_eval.py)
  - Reduces test_dataset.py by 574 lines

  Co-authored-by: Claude <noreply@anthropic.com>

* Refactor eval tests: convert classes to standalone functions

  Address review feedback:
  - Convert TestEvalErrorMessages class to test_eval_error_* functions
  - Convert TestEvalEdgeCases class to test_eval_* functions
  - Convert TestEvalDask class to test_eval_dask_* functions

  This follows xarray's preference for standalone test functions over classes.

  Co-authored-by: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent ef82c56 commit f1526ac

File tree

4 files changed

+838
-32
lines changed

4 files changed

+838
-32
lines changed

xarray/core/dataset.py

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import ast
34
import asyncio
5+
import builtins
46
import copy
57
import datetime
68
import io
@@ -51,6 +53,11 @@
5153
from xarray.core.dataset_utils import _get_virtual_variable, _LocIndexer
5254
from xarray.core.dataset_variables import DataVariables
5355
from xarray.core.duck_array_ops import datetime_to_numeric
56+
from xarray.core.eval import (
57+
EVAL_BUILTINS,
58+
LogicalOperatorTransformer,
59+
validate_expression,
60+
)
5461
from xarray.core.indexes import (
5562
Index,
5663
Indexes,
@@ -72,7 +79,6 @@
7279
Self,
7380
T_ChunkDim,
7481
T_ChunksFreq,
75-
T_DataArray,
7682
T_DataArrayOrSet,
7783
ZarrWriteModes,
7884
)
@@ -9533,19 +9539,48 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
95339539
"Dataset.argmin() with a sequence or ... for dim"
95349540
)
95359541

9542+
def _eval_expression(self, expr: str) -> DataArray:
    """Evaluate an expression string using xarray's native operations.

    The expression is parsed with :mod:`ast`, rewritten by
    ``LogicalOperatorTransformer`` (``and``/``or``/``not`` and chained
    comparisons become element-wise bitwise operations), checked by
    ``validate_expression`` and finally evaluated against a namespace built
    from this dataset.

    Parameters
    ----------
    expr : str
        A single Python expression. Statements such as assignments are
        rejected because the string is parsed in ``"eval"`` mode.

    Returns
    -------
    DataArray
        The value the expression evaluates to. It is annotated as a
        ``DataArray``, but the result type is not checked here.

    Raises
    ------
    ValueError
        If ``expr`` is not a syntactically valid expression, or if
        ``validate_expression`` rejects it.
    """
    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError as e:
        raise ValueError(f"Invalid expression syntax: {expr}") from e

    # Transform logical operators for consistency with query().
    # See LogicalOperatorTransformer docstring for details.
    tree = LogicalOperatorTransformer().visit(tree)
    ast.fix_missing_locations(tree)

    validate_expression(tree)

    # Build the namespace from lowest to highest priority so later updates
    # win on name collisions:
    #   builtins < modules < coordinates < data variables
    # This ensures user data always wins when names collide with builtins.
    import xarray as xr  # Lazy import to avoid circular dependency

    namespace: dict[str, Any] = dict(EVAL_BUILTINS)
    namespace.update({"np": np, "pd": pd, "xr": xr})
    namespace.update({str(name): self.coords[name] for name in self.coords})
    namespace.update({str(name): self[name] for name in self.data_vars})
    # Set last so nothing above can re-enable the real builtins; only the
    # names in EVAL_BUILTINS are reachable. A variable literally named
    # "__builtins__" is therefore not addressable from an expression.
    namespace["__builtins__"] = {}

    code = compile(tree, "<xarray.eval>", "eval")
    # Evaluate with ONE mapping used as globals (locals defaults to it).
    # With separate globals/locals, nested scopes such as generator
    # expressions resolve free names through globals only, so
    # ``sum(a * i for i in range(3))`` would fail with NameError for both
    # ``a`` and ``range``.
    return builtins.eval(code, namespace)
9569+
95369570
def eval(
95379571
self,
95389572
statement: str,
95399573
*,
9540-
parser: QueryParserOptions = "pandas",
9541-
) -> Self | T_DataArray:
9574+
parser: QueryParserOptions | Default = _default,
9575+
) -> Self | DataArray:
95429576
"""
95439577
Calculate an expression supplied as a string in the context of the dataset.
95449578
95459579
This is currently experimental; the API may change particularly around
95469580
assignments, which currently return a ``Dataset`` with the additional variable.
9547-
Currently only the ``python`` engine is supported, which has the same
9548-
performance as executing in python.
9581+
9582+
Logical operators (``and``, ``or``, ``not``) are automatically transformed
9583+
to bitwise operators (``&``, ``|``, ``~``) which work element-wise on arrays.
95499584
95509585
Parameters
95519586
----------
@@ -9555,7 +9590,11 @@ def eval(
95559590
Returns
95569591
-------
95579592
result : Dataset or DataArray, depending on whether ``statement`` contains an
9558-
assignment.
9593+
assignment.
9594+
9595+
Warning
9596+
-------
9597+
Like ``pd.eval()``, this method should not be used with untrusted input.
95599598
95609599
Examples
95619600
--------
@@ -9584,16 +9623,55 @@ def eval(
95849623
b (x) float64 40B 0.0 0.25 0.5 0.75 1.0
95859624
c (x) float64 40B 0.0 1.25 2.5 3.75 5.0
95869625
"""
9626+
if parser is not _default:
9627+
emit_user_level_warning(
9628+
"The 'parser' argument to Dataset.eval() is deprecated and will be "
9629+
"removed in a future version. Logical operators (and/or/not) are now "
9630+
"always transformed to bitwise operators (&/|/~) for array compatibility.",
9631+
FutureWarning,
9632+
)
95879633

9588-
return pd.eval( # type: ignore[return-value]
9589-
statement,
9590-
resolvers=[self],
9591-
target=self,
9592-
parser=parser,
9593-
# Because numexpr returns a numpy array, using that engine results in
9594-
# different behavior. We'd be very open to a contribution handling this.
9595-
engine="python",
9596-
)
9634+
statement = statement.strip()
9635+
9636+
# Check for assignment: "target = expr"
9637+
# Must handle compound operators like ==, !=, <=, >=
9638+
# Use ast to detect assignment properly
9639+
try:
9640+
tree = ast.parse(statement, mode="exec")
9641+
except SyntaxError as e:
9642+
raise ValueError(f"Invalid statement syntax: {statement}") from e
9643+
9644+
if len(tree.body) != 1:
9645+
raise ValueError("Only single statements are supported")
9646+
9647+
stmt = tree.body[0]
9648+
9649+
if isinstance(stmt, ast.Assign):
9650+
# Assignment: "c = a + b"
9651+
if len(stmt.targets) != 1:
9652+
raise ValueError("Only single assignment targets are supported")
9653+
target = stmt.targets[0]
9654+
if not isinstance(target, ast.Name):
9655+
raise ValueError(
9656+
f"Assignment target must be a simple name, got {type(target).__name__}"
9657+
)
9658+
target_name = target.id
9659+
9660+
# Get the expression source
9661+
expr_source = ast.unparse(stmt.value)
9662+
result: DataArray = self._eval_expression(expr_source)
9663+
return self.assign({target_name: result})
9664+
9665+
elif isinstance(stmt, ast.Expr):
9666+
# Expression: "a + b"
9667+
expr_source = ast.unparse(stmt.value)
9668+
return self._eval_expression(expr_source)
9669+
9670+
else:
9671+
raise ValueError(
9672+
f"Unsupported statement type: {type(stmt).__name__}. "
9673+
f"Only expressions and assignments are supported."
9674+
)
95979675

95989676
def query(
95999677
self,

xarray/core/eval.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
Expression evaluation for Dataset.eval().
3+
4+
This module provides AST-based expression evaluation to support N-dimensional
5+
arrays (N > 2), which pd.eval() doesn't support. See GitHub issue #11062.
6+
7+
We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~',
8+
and chained comparisons) for consistency with query(), which still uses
9+
pd.eval(). We don't migrate query() to this implementation because:
10+
- query() typically works fine (expressions usually compare 1D coordinates)
11+
- pd.eval() with numexpr is faster and well-tested for query's use case
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import ast
17+
import builtins
18+
from typing import Any
19+
20+
# Names made available to eval expressions.
# Expressions are evaluated with an empty ``__builtins__`` dict, so every
# builtin that should remain usable has to be listed here explicitly. Each
# entry is registered under its own ``__name__``.
EVAL_BUILTINS: dict[str, Any] = {
    func.__name__: func
    for func in (
        # Numeric/aggregation functions
        abs,
        min,
        max,
        round,
        len,
        sum,
        pow,
        any,
        all,
        # Type constructors
        int,
        float,
        bool,
        str,
        list,
        tuple,
        dict,
        set,
        slice,
        # Iteration helpers
        range,
        zip,
        enumerate,
        builtins.map,
        filter,
    )
}
50+
51+
52+
class LogicalOperatorTransformer(ast.NodeTransformer):
53+
"""Transform operators for consistency with query().
54+
55+
query() uses pd.eval() which transforms these operators automatically.
56+
We replicate that behavior here so syntax that works in query() also
57+
works in eval().
58+
59+
Transformations:
60+
1. 'and'/'or'/'not' -> '&'/'|'/'~'
61+
2. 'a < b < c' -> '(a < b) & (b < c)'
62+
63+
These constructs fail on arrays in standard Python because they call
64+
__bool__(), which is ambiguous for multi-element arrays.
65+
"""
66+
67+
def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
68+
# Transform: a and b -> a & b, a or b -> a | b
69+
self.generic_visit(node)
70+
op: ast.BitAnd | ast.BitOr
71+
if isinstance(node.op, ast.And):
72+
op = ast.BitAnd()
73+
elif isinstance(node.op, ast.Or):
74+
op = ast.BitOr()
75+
else:
76+
return node
77+
78+
# BoolOp can have multiple values: a and b and c
79+
# Transform to chained BinOp: (a & b) & c
80+
result = node.values[0]
81+
for value in node.values[1:]:
82+
result = ast.BinOp(left=result, op=op, right=value)
83+
return ast.fix_missing_locations(result)
84+
85+
def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
86+
# Transform: not a -> ~a
87+
self.generic_visit(node)
88+
if isinstance(node.op, ast.Not):
89+
return ast.fix_missing_locations(
90+
ast.UnaryOp(op=ast.Invert(), operand=node.operand)
91+
)
92+
return node
93+
94+
def visit_Compare(self, node: ast.Compare) -> ast.AST:
95+
# Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5)
96+
# Python's chained comparisons use short-circuit evaluation at runtime,
97+
# which calls __bool__ on intermediate results. This fails for arrays.
98+
# We transform to bitwise AND which works element-wise.
99+
self.generic_visit(node)
100+
101+
if len(node.ops) == 1:
102+
# Simple comparison, no transformation needed
103+
return node
104+
105+
# Build individual comparisons and chain with BitAnd
106+
# For: a < b < c < d
107+
# We need: (a < b) & (b < c) & (c < d)
108+
comparisons = []
109+
left = node.left
110+
for op, comparator in zip(node.ops, node.comparators, strict=True):
111+
comp = ast.Compare(left=left, ops=[op], comparators=[comparator])
112+
comparisons.append(comp)
113+
left = comparator
114+
115+
# Chain with BitAnd: (a < b) & (b < c) & ...
116+
result: ast.Compare | ast.BinOp = comparisons[0]
117+
for comp in comparisons[1:]:
118+
result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp)
119+
return ast.fix_missing_locations(result)
120+
121+
122+
def validate_expression(tree: ast.AST) -> None:
123+
"""Validate that an AST doesn't contain patterns we don't support.
124+
125+
These restrictions emulate pd.eval() behavior for consistency.
126+
"""
127+
for node in ast.walk(tree):
128+
# Block lambda expressions (pd.eval: "Only named functions are supported")
129+
if isinstance(node, ast.Lambda):
130+
raise ValueError(
131+
"Lambda expressions are not allowed in eval(). "
132+
"Use direct operations on data variables instead."
133+
)
134+
# Block private/dunder attributes (consistent with pd.eval restrictions)
135+
if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
136+
raise ValueError(
137+
f"Access to private attributes is not allowed: '{node.attr}'"
138+
)

xarray/tests/test_dataset.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7673,23 +7673,6 @@ def test_query(self, backend, engine, parser) -> None:
76737673
# pytest tests — new tests should go here, rather than in the class.
76747674

76757675

7676-
@pytest.mark.parametrize("parser", ["pandas", "python"])
7677-
def test_eval(ds, parser) -> None:
7678-
"""Currently much more minimal testing that `query` above, and much of the setup
7679-
isn't used. But the risks are fairly low — `query` shares much of the code, and
7680-
the method is currently experimental."""
7681-
7682-
actual = ds.eval("z1 + 5", parser=parser)
7683-
expect = ds["z1"] + 5
7684-
assert_identical(expect, actual)
7685-
7686-
# check pandas query syntax is supported
7687-
if parser == "pandas":
7688-
actual = ds.eval("(z1 > 5) and (z2 > 0)", parser=parser)
7689-
expect = (ds["z1"] > 5) & (ds["z2"] > 0)
7690-
assert_identical(expect, actual)
7691-
7692-
76937676
@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2])))
76947677
def test_isin(test_elements, backend) -> None:
76957678
expected = Dataset(

0 commit comments

Comments
 (0)