
Commit aa36d37

Ensure HLGExpr tokenize uniquely (dask#11849)
1 parent e0877d0 commit aa36d37

8 files changed: +45 -28 lines changed

dask/_expr.py

Lines changed: 15 additions & 2 deletions
@@ -136,7 +136,13 @@ def __hash__(self):
         return hash(self._name)

     def __dask_tokenize__(self):
-        return self._name
+        if not self._determ_token:
+            # If the subclass does not implement a __dask_tokenize__ we'll want
+            # to tokenize all operands.
+            # Note how this differs to the implementation of
+            # Expr.deterministic_token
+            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
+        return self._determ_token

     @staticmethod
     def _reconstruct(*args):

@@ -494,7 +500,9 @@ def _funcname(self) -> str:
     @property
     def deterministic_token(self):
         if not self._determ_token:
-            self._determ_token = _tokenize_deterministic(*self.operands)
+            # Just tokenize self to fall back on __dask_tokenize__
+            # Note how this differs to the implementation of __dask_tokenize__
+            self._determ_token = self.__dask_tokenize__()
         return self._determ_token

     @functools.cached_property

@@ -1074,6 +1082,11 @@ def __dask_keys__(self) -> list:
             all_keys.append(op.__dask_keys__())
         return all_keys

+    def __repr__(self):
+        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"
+
+    __str__ = __repr__
+
     def finalize_compute(self):
         return _ExprSequence(
             *(op.finalize_compute() for op in self.operands),
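
Taken together, the hunks above invert the old relationship: __dask_tokenize__ is now the fallback that hashes the concrete type plus all operands, while the deterministic_token property only caches whatever __dask_tokenize__ (possibly overridden by a subclass) returns. Below is a minimal, self-contained sketch of that interplay; the SketchExpr and CustomToken classes are hypothetical and use the public dask.base.tokenize in place of the internal _tokenize_deterministic.

from dask.base import tokenize


class SketchExpr:
    """Hypothetical stand-in for dask._expr.Expr (heavily simplified)."""

    def __init__(self, *operands):
        self.operands = list(operands)
        self._determ_token = None

    def __dask_tokenize__(self):
        # Fallback: hash the concrete type together with all operands,
        # mirroring the new Expr.__dask_tokenize__ above.
        if not self._determ_token:
            self._determ_token = tokenize(type(self), *self.operands)
        return self._determ_token

    @property
    def deterministic_token(self):
        # Delegate to __dask_tokenize__ so subclass overrides are honoured.
        if not self._determ_token:
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token


class CustomToken(SketchExpr):
    def __dask_tokenize__(self):
        # A subclass supplying its own token, as Blockwise, PartialReduce and
        # ReadParquet do below; deterministic_token picks it up unchanged.
        if not self._determ_token:
            self._determ_token = tokenize("custom-token", *self.operands)
        return self._determ_token


a, b = SketchExpr(1, 2), SketchExpr(1, 3)
assert a.deterministic_token != b.deterministic_token  # operands differ
assert CustomToken(1, 2).deterministic_token == CustomToken(1, 2).deterministic_token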

dask/array/_array_expr/_blockwise.py

Lines changed: 1 addition & 2 deletions
@@ -116,8 +116,7 @@ def chunks(self):
     def dtype(self):
         return self.operand("dtype")

-    @property
-    def deterministic_token(self):
+    def __dask_tokenize__(self):
         if not self._determ_token:
             # TODO: Is there an actual need to overwrite this?
             self._determ_token = _tokenize_deterministic(

dask/array/_array_expr/_expr.py

Lines changed: 0 additions & 3 deletions
@@ -104,9 +104,6 @@ def unwrap(task):

         return unwrap(key_refs)

-    def __dask_tokenize__(self):
-        return self._name
-
     def __hash__(self):
         return hash(self._name)


dask/array/_array_expr/_reductions.py

Lines changed: 1 addition & 2 deletions
@@ -276,8 +276,7 @@ class PartialReduce(ArrayExpr):
         "reduced_meta": None,
     }

-    @property
-    def deterministic_token(self):
+    def __dask_tokenize__(self):
         if not self._determ_token:
             # TODO: Is there an actual need to overwrite this?
             self._determ_token = _tokenize_deterministic(

dask/dataframe/dask_expr/_expr.py

Lines changed: 0 additions & 6 deletions
@@ -49,7 +49,6 @@
     raise_on_meta_error,
     valid_divisions,
 )
-from dask.tokenize import normalize_token
 from dask.typing import Key, no_default
 from dask.utils import (
     M,

@@ -3112,11 +3111,6 @@ def ndim(self):
         return 0


-@normalize_token.register(Expr)
-def normalize_expression(expr):
-    return expr._name
-
-
 def is_broadcastable(dfs, s):
     """
     This Series is broadcastable against another dataframe in the sequence
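
The registration deleted above was the dispatch hook that reduced every Expr to nothing but its _name for tokenization, so any two expressions sharing a _name tokenized identically regardless of their contents, which is the behaviour this commit tightens for HLGExpr. For context, here is a small sketch of what normalize_token.register does in general; the Point class and _normalize_point helper are hypothetical.

from dask.base import normalize_token, tokenize


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


@normalize_token.register(Point)
def _normalize_point(p):
    # Tell the tokenizer which attributes define a Point's identity.
    return ("Point", p.x, p.y)


assert tokenize(Point(1, 2)) == tokenize(Point(1, 2))
assert tokenize(Point(1, 2)) != tokenize(Point(1, 3))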

dask/dataframe/dask_expr/io/parquet.py

Lines changed: 1 addition & 2 deletions
@@ -776,8 +776,7 @@ def columns(self):
     def _funcname(self):
         return "read_parquet"

-    @property
-    def deterministic_token(self):
+    def __dask_tokenize__(self):
         if not self._determ_token:
             # TODO: Is there an actual need to overwrite this?
             self._determ_token = _tokenize_deterministic(

dask/tests/test_base.py

Lines changed: 0 additions & 11 deletions
@@ -19,7 +19,6 @@
 from dask.base import (
     DaskMethodsMixin,
     clone_key,
-    collections_to_expr,
     compute,
     compute_as_if_collection,
     get_collection_names,

@@ -931,16 +930,6 @@ def test_num_workers_config(scheduler):
     assert len(workers) == num_workers


-def test_optimizations_ctd():
-    pytest.importorskip("numpy")
-    da = pytest.importorskip("dask.array")
-    x = da.arange(2, chunks=1)[:1]
-    dsk1 = collections_to_expr([x])
-    with dask.config.set({"optimizations": [lambda dsk, keys: dsk]}):
-        dsk2 = collections_to_expr([x])
-    assert dsk1 == dsk2
-
-
 def test_clone_key():
     for key, seed in [("x", 123), (("x", 1), 456), (("sum-1-2-3", h1, 1), 123)]:
         validate_key(clone_key(key, seed))

dask/tests/test_hlgexpr.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import pickle
+
+from dask._expr import HLGExpr
+from dask.tokenize import tokenize
+
+
+def test_tokenize():
+    # Ensure tokens are different for different high-level graphs. The current
+    # implementation actually ensures that no two HLGExprs tokenize equally.
+    # Technically, we do not need such a strong guarantee, but tokenizing a full
+    # HLG reliably is tricky and we do not require reproducibility for
+    # HLGExpr since they do not undergo the same kind of optimization as the
+    # rest of the graph.
+    from dask.highlevelgraph import HighLevelGraph
+
+    dsk = HighLevelGraph.from_collections("x", {"foo": None})
+    dsk2 = HighLevelGraph.from_collections("x", {"bar": None})
+    dsk3 = HighLevelGraph.from_collections("y", {"foo": None})
+    assert tokenize(HLGExpr(dsk)) != tokenize(HLGExpr(dsk2))
+    assert tokenize(HLGExpr(dsk)) != tokenize(HLGExpr(dsk3))
+    assert tokenize(HLGExpr(dsk2)) != tokenize(HLGExpr(dsk3))
+
+    # Roundtrip preserves the tokens
+    for expr in [HLGExpr(dsk), HLGExpr(dsk2), HLGExpr(dsk3)]:
+        assert tokenize(pickle.loads(pickle.dumps(expr))) == tokenize(expr)
