Skip to content

Commit 5a3dfb4

Browse files
authored
Ensure repartition does not trigger memory size computation during lowering (i.e. on the scheduler) (dask#11855)
1 parent 1221d34 commit 5a3dfb4

File tree

5 files changed: +87 / -8 lines changed

dask/_expr.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ class Expr:
4545
_parameters: list[str] = []
4646
_defaults: dict[str, Any] = {}
4747
_instances: weakref.WeakValueDictionary[str, Expr] = weakref.WeakValueDictionary()
48+
_pickle_functools_cache: bool = True
4849

4950
operands: list
5051

@@ -146,14 +147,23 @@ def __dask_tokenize__(self):
146147

147148
@staticmethod
148149
def _reconstruct(*args):
149-
typ, *operands, token = args
150-
return typ(*operands, _determ_token=token)
150+
typ, *operands, token, cache = args
151+
inst = typ(*operands, _determ_token=token)
152+
for k, v in cache.items():
153+
inst.__dict__[k] = v
154+
return inst
151155

152156
def __reduce__(self):
153157
if dask.config.get("dask-expr-no-serialize", False):
154158
raise RuntimeError(f"Serializing a {type(self)} object")
159+
cache = {}
160+
if type(self)._pickle_functools_cache:
161+
for k, v in type(self).__dict__.items():
162+
if isinstance(v, functools.cached_property) and k in self.__dict__:
163+
cache[k] = getattr(self, k)
164+
155165
return Expr._reconstruct, tuple(
156-
[type(self), *self.operands, self.deterministic_token]
166+
[type(self), *self.operands, self.deterministic_token, cache]
157167
)
158168

159169
def _depth(self, cache=None):

dask/dataframe/dask_expr/_repartition.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -415,21 +415,26 @@ def _lower(self):
415415

416416

417417
class RepartitionSize(Repartition):
418+
418419
@functools.cached_property
419420
def _size(self):
420421
size = self.operand("partition_size")
421422
if isinstance(size, str):
422423
size = parse_bytes(size)
423424
return int(size)
424425

426+
@functools.cached_property
427+
def _mem_usage(self):
428+
return _get_mem_usages(self.frame)
429+
425430
@functools.cached_property
426431
def _nsplits(self):
427-
return 1 + _get_mem_usages(self.frame) // self._size
432+
return 1 + self._mem_usage // self._size
428433

429434
@functools.cached_property
430435
def _partition_boundaries(self):
431436
nsplits = self._nsplits
432-
mem_usages = _get_mem_usages(self.frame)
437+
mem_usages = self._mem_usage
433438

434439
if np.any(nsplits > 1):
435440
split_mem_usages = []
@@ -449,6 +454,11 @@ def _divisions(self):
449454
return (None,) * len(self._partition_boundaries)
450455
return (self.frame.divisions[i] for i in self._partition_boundaries)
451456

457+
def _lower(self):
458+
# populate cache
459+
self._mem_usage # noqa
460+
return super()._lower()
461+
452462
def _layer(self) -> dict:
453463
df = self.frame
454464
dsk: dict[tuple, Any] = {}

dask/dataframe/dask_expr/io/parquet.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -716,6 +716,7 @@ def default_types_mapper(pyarrow_dtype):
716716

717717

718718
class ReadParquet(PartitionsFiltered, BlockwiseIO):
719+
_pickle_functools_cache = False
719720
_absorb_projections = True
720721
_filter_passthrough = False
721722

dask/dataframe/dask_expr/tests/test_collection.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1370,7 +1370,7 @@ def test_column_getattr(df):
13701370
def test_serialization(pdf, df):
13711371
before = pickle.dumps(df)
13721372

1373-
assert len(before) < 200 + len(pickle.dumps(pdf))
1373+
assert len(before) < 350 + len(pickle.dumps(pdf))
13741374

13751375
part = df.partitions[0].compute()
13761376
assert (
@@ -1380,8 +1380,6 @@ def test_serialization(pdf, df):
13801380

13811381
after = pickle.dumps(df)
13821382

1383-
assert before == after # caching doesn't affect serialization
1384-
13851383
assert pickle.loads(before)._name == pickle.loads(after)._name
13861384
assert_eq(pickle.loads(before), pickle.loads(after))
13871385

dask/tests/test_expr.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import functools
4+
import pickle
5+
36
import pytest
47

58
from dask._expr import Expr
@@ -28,3 +31,60 @@ def test_setattr2():
2831
assert e.bar == 3
2932
with pytest.raises(AttributeError):
3033
e.baz = 4
34+
35+
36+
class MyExprCachedProperty(Expr):
37+
called_cached_property = False
38+
_parameters = ["foo", "bar"]
39+
40+
@property
41+
def baz(self):
42+
return self.foo + self.bar
43+
44+
@functools.cached_property
45+
def cached_property(self):
46+
if MyExprCachedProperty.called_cached_property:
47+
raise RuntimeError("No!")
48+
MyExprCachedProperty.called_cached_property = True
49+
return self.foo + self.bar
50+
51+
52+
def test_pickle_cached_properties():
53+
pytest.importorskip("distributed")
54+
from distributed import Nanny
55+
from distributed.utils_test import gen_cluster
56+
57+
@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
58+
async def test(c, s, a):
59+
60+
expr = MyExprCachedProperty(foo=1, bar=2)
61+
for _ in range(10):
62+
assert expr.baz == 3
63+
assert expr.cached_property == 3
64+
65+
assert MyExprCachedProperty.called_cached_property is True
66+
67+
rt = pickle.loads(pickle.dumps(expr))
68+
assert rt.cached_property == 3
69+
assert MyExprCachedProperty.called_cached_property is True
70+
71+
# Expressions are singletons, i.e. this doesn't crash
72+
expr2 = MyExprCachedProperty(foo=1, bar=2)
73+
assert expr2.cached_property == 3
74+
75+
# But this does
76+
expr3 = MyExprCachedProperty(foo=1, bar=3)
77+
with pytest.raises(RuntimeError):
78+
expr3.cached_property
79+
80+
def f(expr):
81+
# The cache must travel with the pickle: this function runs in a
82+
# different process, where the class-level flag starts out reset, so the
83+
# property can only be read without side effects if its value was shipped
84+
assert MyExprCachedProperty.called_cached_property is False
85+
assert expr.cached_property == 3
86+
assert MyExprCachedProperty.called_cached_property is False
87+
88+
await c.submit(f, expr)
89+
90+
test()

0 commit comments

Comments
 (0)