Skip to content

Commit bd3a7d6

Browse files
authored
Expr setattr (dask#11836)
1 parent c85bae2 commit bd3a7d6

File tree

6 files changed

+77
-37
lines changed

6 files changed

+77
-37
lines changed

dask/_expr.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ def _depth(self, cache=None):
164164
cache[expr._name] = result[-1]
165165
return max(result)
166166

167+
def __setattr__(self, name: str, value: Any) -> None:
    """Assign ``value`` to attribute ``name`` on this expression.

    ``operands`` and ``_determ_token`` are stored as ordinary instance
    attributes.  Every other name must appear in the class-level
    ``_parameters`` list; the value is then written into the slot of
    ``operands`` at the same position instead of onto the instance.

    Parameters
    ----------
    name : str
        Attribute (or parameter) name being assigned.
    value : Any
        New value.

    Raises
    ------
    AttributeError
        If ``name`` is neither one of the two plain attributes nor a
        declared parameter.
    """
    if name in ("operands", "_determ_token"):
        object.__setattr__(self, name, value)
        return

    # Regular class attribute lookup walks the MRO, so subclasses that
    # inherit ``_parameters`` from a parent are handled as well.
    # ``object.__getattribute__(type(self), ...)`` would only inspect the
    # class's own ``__dict__`` and fail for inherited parameter lists.
    params = type(self)._parameters
    # Fetch the operand list straight from the instance, bypassing any
    # custom attribute hooks the class may define.
    operands = object.__getattribute__(self, "operands")

    try:
        position = params.index(name)
    except ValueError:
        # The failed list search is an implementation detail; hide it from
        # the traceback the caller sees.
        raise AttributeError(
            f"{type(self).__name__} object has no attribute {name}"
        ) from None
    operands[position] = value
179+
167180
def operand(self, key):
168181
# Access an operand unambiguously
169182
# (e.g. if the key is reserved by a method/property)

dask/array/_array_expr/_expr.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tlz import accumulate
1212

1313
from dask._expr import Expr
14-
from dask._task_spec import Task, TaskRef
14+
from dask._task_spec import List, Task, TaskRef
1515
from dask.array.chunk import getitem
1616
from dask.array.core import T_IntOrNaN, common_blockdim, unknown_chunk_message
1717
from dask.blockwise import broadcast_dimensions
@@ -20,7 +20,6 @@
2020

2121

2222
class ArrayExpr(Expr):
23-
_cached_keys = None
2423

2524
def _operands_for_repr(self):
2625
return []
@@ -75,25 +74,35 @@ def __len__(self):
7574
raise ValueError(msg)
7675
return int(sum(self.chunks[0]))
7776

78-
def __dask_keys__(self):
77+
@functools.cached_property
78+
def _cached_keys(self):
7979
out = self.lower_completely()
80-
if self._cached_keys is not None:
81-
return self._cached_keys
8280

8381
name, chunks, numblocks = out.name, out.chunks, out.numblocks
8482

8583
def keys(*args):
8684
if not chunks:
87-
return [(name,)]
85+
return List(TaskRef((name,)))
8886
ind = len(args)
8987
if ind + 1 == len(numblocks):
90-
result = [(name,) + args + (i,) for i in range(numblocks[ind])]
88+
result = List(
89+
*(TaskRef((name,) + args + (i,)) for i in range(numblocks[ind]))
90+
)
9191
else:
92-
result = [keys(*(args + (i,))) for i in range(numblocks[ind])]
92+
result = List(*(keys(*(args + (i,))) for i in range(numblocks[ind])))
9393
return result
9494

95-
self._cached_keys = result = keys()
96-
return result
95+
return keys()
96+
97+
def __dask_keys__(self):
    """Return this expression's output keys as plain nested lists.

    ``self._cached_keys`` holds the keys as nested ``List`` containers;
    this walks that structure and rebuilds it from ordinary Python lists,
    replacing every non-``List`` leaf with its ``.key`` attribute.
    """

    def to_plain(node):
        # Anything that is not a ``List`` is a leaf exposing ``.key``.
        if not isinstance(node, List):
            return node.key
        return [to_plain(child) for child in node.args]

    return to_plain(self._cached_keys)
97106

98107
def __dask_tokenize__(self):
99108
return self._name

dask/dataframe/dask_expr/_expr.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,7 @@ def known_divisions(self):
438438
@property
439439
def npartitions(self):
440440
if "npartitions" in self._parameters:
441-
idx = self._parameters.index("npartitions")
442-
return self.operands[idx]
441+
return self.operand("npartitions")
443442
else:
444443
return len(self.divisions) - 1
445444

dask/dataframe/dask_expr/io/io.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
398398
"pyarrow_strings_enabled",
399399
"_partitions",
400400
"_series",
401+
"_pd_length_stats",
401402
]
402403
_defaults = {
403404
"npartitions": None,
@@ -407,8 +408,9 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
407408
"_series": False,
408409
"chunksize": None,
409410
"pyarrow_strings_enabled": True,
411+
"_pd_length_stats": None,
410412
}
411-
_pd_length_stats = None
413+
_pd_length_stats: tuple | None
412414
_absorb_projections = True
413415

414416
@functools.cached_property
@@ -538,8 +540,14 @@ class FromPandasDivisions(FromPandas):
538540
"pyarrow_strings_enabled",
539541
"_partitions",
540542
"_series",
543+
"_pd_length_stats",
541544
]
542-
_defaults = {"columns": None, "_partitions": None, "_series": False}
545+
_defaults = {
546+
"columns": None,
547+
"_partitions": None,
548+
"_series": False,
549+
"_pd_length_stats": None,
550+
}
543551
sort = True
544552

545553
@functools.cached_property

dask/dataframe/dask_expr/io/parquet.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,6 @@ def default_types_mapper(pyarrow_dtype):
717717

718718

719719
class ReadParquet(PartitionsFiltered, BlockwiseIO):
720-
_pq_length_stats = None
721720
_absorb_projections = True
722721
_filter_passthrough = False
723722

@@ -1074,9 +1073,7 @@ def _dataset_info(self):
10741073

10751074
dataset_info["schema"] = dataset.schema
10761075
dataset_info["base_meta"] = dataset.schema.empty_table().to_pandas()
1077-
self.operands[type(self)._parameters.index("_dataset_info_cache")] = (
1078-
dataset_info
1079-
)
1076+
self._dataset_info_cache = dataset_info
10801077
return dataset_info
10811078

10821079
@cached_property
@@ -1279,6 +1276,7 @@ class ReadParquetFSSpec(ReadParquet):
12791276
"_partitions",
12801277
"_series",
12811278
"_dataset_info_cache",
1279+
"_pq_length_stats",
12821280
]
12831281
_defaults = {
12841282
"columns": None,
@@ -1299,6 +1297,7 @@ class ReadParquetFSSpec(ReadParquet):
12991297
"_partitions": None,
13001298
"_series": False,
13011299
"_dataset_info_cache": None,
1300+
"_pq_length_stats": None,
13021301
}
13031302

13041303
@property
@@ -1410,9 +1409,7 @@ def _dataset_info(self):
14101409
dataset_info["all_columns"] = all_columns
14111410
dataset_info["calculate_divisions"] = self.calculate_divisions
14121411

1413-
self.operands[type(self)._parameters.index("_dataset_info_cache")] = (
1414-
dataset_info
1415-
)
1412+
# Store the freshly built info in the ``_dataset_info_cache`` parameter.
self._dataset_info_cache = dataset_info
14161413
return dataset_info
14171414

14181415
def _filtered_task(self, name: Key, index: int) -> Task:
@@ -1480,29 +1477,27 @@ def _get_lengths(self) -> tuple | None:
14801477
"""Return known partition lengths using parquet statistics"""
14811478
if not self.filters:
14821479
self._update_length_statistics()
1483-
return tuple( # type: ignore
1480+
return tuple(
14841481
length
1485-
for i, length in enumerate(self._pq_length_stats) # type: ignore
1482+
for i, length in enumerate(self._pq_length_stats)
14861483
if not self._filtered or i in self._partitions
14871484
)
14881485
return None
14891486

1490-
def _update_length_statistics(self):
1487+
@cached_property
1488+
def _pq_length_stats(self):
14911489
"""Ensure that partition-length statistics are up to date"""
14921490

1493-
if not self._pq_length_stats:
1494-
if self._plan["statistics"]:
1495-
# Already have statistics from original API call
1496-
self._pq_length_stats = tuple(
1497-
stat["num-rows"]
1498-
for i, stat in enumerate(self._plan["statistics"])
1499-
if not self._filtered or i in self._partitions
1500-
)
1501-
else:
1502-
# Need to go back and collect statistics
1503-
self._pq_length_stats = tuple(
1504-
stat["num-rows"] for stat in _collect_pq_statistics(self)
1505-
)
1491+
if self._plan["statistics"]:
1492+
# Already have statistics from original API call
1493+
return tuple(
1494+
stat["num-rows"]
1495+
for i, stat in enumerate(self._plan["statistics"])
1496+
if not self._filtered or i in self._partitions
1497+
)
1498+
else:
1499+
# Need to go back and collect statistics
1500+
return tuple(stat["num-rows"] for stat in _collect_pq_statistics(self))
15061501

15071502

15081503
#

dask/tests/test_expr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from dask._expr import Expr
6+
7+
8+
def test_setattr():
    """Declared parameters can be reassigned; unknown names are rejected."""

    class MyExpr(Expr):
        _parameters = ["foo", "bar"]

    expr = MyExpr(foo=1, bar=2)

    # Assigning to a declared parameter is reflected on subsequent reads.
    setattr(expr, "bar", 3)
    assert expr.bar == 3

    # A name outside ``_parameters`` must raise instead of being stored.
    with pytest.raises(AttributeError):
        setattr(expr, "baz", 4)

0 commit comments

Comments
 (0)