Skip to content

Commit b2a4a21

Browse files
authored
Wrap HLGs in an Expr to avoid Client side materialization (dask#11736)
1 parent 0b52650 commit b2a4a21

29 files changed

+630
-308
lines changed

dask/_expr.py

Lines changed: 325 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import functools
44
import os
5+
import warnings
56
import weakref
67
from collections import defaultdict
78
from collections.abc import Generator
@@ -11,9 +12,10 @@
1112

1213
import dask
1314
from dask._task_spec import Task
15+
from dask.core import reverse_dict
1416
from dask.tokenize import _tokenize_deterministic
1517
from dask.typing import Key
16-
from dask.utils import funcname, import_required
18+
from dask.utils import ensure_dict, funcname, import_required
1719

1820
if TYPE_CHECKING:
1921
# TODO import from typing (requires Python >=3.10)
@@ -73,8 +75,13 @@ def _tune_down(self):
7375
def _tune_up(self, parent):
7476
return None
7577

78+
def finalize_compute(self):
79+
return self
80+
7681
def _operands_for_repr(self):
77-
raise NotImplementedError("Subclasses should implement this method")
82+
return [
83+
f"{param}={repr(op)}" for param, op in zip(self._parameters, self.operands)
84+
]
7885

7986
def __str__(self):
8087
s = ", ".join(self._operands_for_repr())
@@ -99,7 +106,7 @@ def _tree_repr_argument_construction(self, i, op, header):
99106
return header
100107

101108
def _tree_repr_lines(self, indent=0, recursive=True):
102-
raise NotImplementedError("Subclasses should implement this method")
109+
return " " * indent + repr(self)
103110

104111
def tree_repr(self):
105112
return os.linesep.join(self._tree_repr_lines())
@@ -140,7 +147,7 @@ def __reduce__(self):
140147
if dask.config.get("dask-expr-no-serialize", False):
141148
raise RuntimeError(f"Serializing a {type(self)} object")
142149
return Expr._reconstruct, tuple(
143-
[type(self)] + self.operands + [self.deterministic_token]
150+
[type(self), *self.operands, self.deterministic_token]
144151
)
145152

146153
def _depth(self, cache=None):
@@ -498,6 +505,9 @@ def _name(self) -> str:
498505
def _meta(self):
499506
raise NotImplementedError()
500507

508+
def __dask_annotations__(self):
509+
return {}
510+
501511
def __dask_graph__(self):
502512
"""Traverse expression tree, collect layers"""
503513
stack = [self]
@@ -862,3 +872,314 @@ def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
862872
return expr
863873

864874
raise ValueError(f"Stage {stage!r} not supported.")
875+
876+
877+
class LLGExpr(Expr):
    """Low Level Graph Expression

    Holds a low-level task graph in its single ``dsk`` operand so the graph
    can be passed around as an expression.
    """

    _parameters = ["dsk"]

    def __dask_keys__(self):
        """Return the keys of the wrapped graph as a list."""
        graph = self.operand("dsk")
        return list(graph)

    def __dask_tokenize__(self):
        """Tokenize by object identity instead of by graph content."""
        return str(id(self))

    def _layer(self) -> dict:
        """Return the wrapped graph passed through ``ensure_dict``."""
        graph = self.operand("dsk")
        return ensure_dict(graph)
890+
891+
892+
class HLGExpr(Expr):
    """Expression wrapping the graph of a HighLevelGraph-backed collection.

    Operands
    --------
    dsk
        The wrapped graph. ``from_collection`` always passes a
        ``HighLevelGraph`` here.
    low_level_optimizer
        Optional callable invoked as ``optimizer(dsk, keys)``; ``None``
        disables optimization.
    output_keys
        Keys to compute. When ``None`` they are derived from the graph on
        first use (see ``__dask_keys__``).
    postcompute
        The collection's ``__dask_postcompute__`` bound method.
    _cached_optimized
        Cache slot for the result of ``_optimized_dsk``.
    """

    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        """Build an ``HLGExpr`` from a dask collection.

        Parameters
        ----------
        collection
            Object implementing the dask collection protocol. Its graph is
            taken from ``collection.dask`` (copied) when that attribute
            exists, otherwise from ``collection.__dask_graph__()``.
        optimize_graph : bool
            When true, ``collection.__dask_optimize__`` is recorded as the
            low level optimizer. If the collection does not define it, a
            ``PendingDeprecationWarning`` is issued and no optimizer is used.

        Returns
        -------
        HLGExpr
        """
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__,
        )

    def finalize_compute(self):
        """Wrap this expression in an ``HLGFinalizeCompute``."""
        return HLGFinalizeCompute(self)

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        """Collect layer annotations as ``{annotation_type: {key: value}}``.

        Callable annotation values are evaluated once per key of the layer.
        """
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk()
        if isinstance(dsk, dict):
            # A plain dict has no layers to inspect; fall back to the
            # unoptimized graph.
            dsk = self.dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        """Return the output keys of this expression.

        Explicitly provided ``output_keys`` are returned as-is, including an
        empty list. Only when no keys were provided (``None``) are they
        derived from the graph as the keys that nothing else depends on.
        """
        # Compare against None rather than relying on truthiness: an empty
        # list of output keys is a valid answer and must not trigger
        # materialization of the graph.
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.operand("dsk")
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        dependents = reverse_dict(dependencies)
        keys = [d for d in dependents if not dependents[d] and d in dsk]
        # NOTE(review): this stores the keys as an attribute, while the lookup
        # above goes through `operand()`; confirm whether the cached value is
        # visible there or only via `self.output_keys`.
        self.output_keys = keys
        return keys

    def __dask_tokenize__(self):
        """Tokenize by object identity."""
        # There is currently no way to hash a HighLevelGraph fast and reliably.
        # It is important for dask-expr for this to not be duplicated so we'll
        # just use the ID.
        return str(id(self))

    def _optimized_dsk(self):
        """Return the graph after applying ``low_level_optimizer`` (cached).

        Without an optimizer the wrapped graph is returned unchanged.
        """
        cached = self._cached_optimized
        # `is not None` so that an empty optimized graph is still a cache hit.
        if cached is not None:
            return cached
        keys = self.output_keys
        optimizer = self.low_level_optimizer
        dsk = self.dsk
        if optimizer is not None:
            if keys is None:
                keys = self.__dask_keys__()
            dsk = optimizer(dsk, keys)
        self._cached_optimized = dsk
        return dsk

    def _layer(self) -> dict:
        """Return the optimized graph passed through ``ensure_dict``."""
        dsk = self._optimized_dsk()
        return ensure_dict(dsk)
995+
996+
997+
class _HLGExprSequence(Expr):
    """Sequence of graph-backed expressions whose graphs are merged together.

    Operands that share the same ``low_level_optimizer`` have their graphs
    merged first and are then optimized as one unit.
    """

    def __getitem__(self, other):
        """Index directly into the operands."""
        return self.operands[other]

    def _operands_for_repr(self):
        # NOTE(review): this class declares no `_parameters` of its own;
        # confirm that the "name" and "dsk" operands actually resolve here.
        name_part = f"name={self.operand('name')!r}"
        dsk_part = f"dsk={self.operand('dsk')!r}"
        return [name_part, dsk_part]

    def _tree_repr_lines(self, indent=0, recursive=True):
        """Return the repr fragments; ``indent``/``recursive`` are unused."""
        return self._operands_for_repr()

    def finalize_compute(self):
        """Wrap the whole sequence in an ``HLGFinalizeCompute``."""
        return HLGFinalizeCompute(self)

    def __dask_graph__(self):
        """Merge all operand graphs, optimizing per shared optimizer."""
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        from dask.highlevelgraph import HighLevelGraph

        def _optimizer_of(op):
            # Operands that are not HLGExpr land in the ``None`` group.
            return op.low_level_optimizer if isinstance(op, HLGExpr) else None

        merged_per_group = []
        for optimizer, members in toolz.groupby(_optimizer_of, self.operands).items():
            # HLGExpr contributes its raw graph; anything else (FinalizeCompute)
            # contributes its layer.
            member_graphs = [
                op.dsk if isinstance(op, HLGExpr) else op._layer() for op in members
            ]
            combined = HighLevelGraph.merge(*member_graphs)
            group_keys = [op.__dask_keys__() for op in members]
            if optimizer is not None:
                combined = optimizer(combined, group_keys)
            merged_per_group.append(combined)

        return ensure_dict(HighLevelGraph.merge(*merged_per_group))

    _layer = __dask_graph__

    def __dask_annotations__(self):
        """Merge the annotations of all operands, grouped by annotation type."""
        merged = {}
        for operand in self.operands:
            for annot_type, per_key in operand.__dask_annotations__().items():
                merged.setdefault(annot_type, {}).update(per_key)
        return merged

    def __dask_keys__(self) -> list:
        """Return one entry per operand holding that operand's keys."""
        return [operand.__dask_keys__() for operand in self.operands]
1056+
1057+
1058+
class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g. when
    being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        """Index directly into the operands."""
        return self.operands[other]

    def _layer(self) -> dict:
        """Return the merged layers of all operands."""
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        """Return one entry per operand holding that operand's keys."""
        return [op.__dask_keys__() for op in self.operands]

    def finalize_compute(self):
        """Return a new sequence with every operand finalized."""
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        """Merge the annotations of all operands, grouped by annotation type."""
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        """Rewrite into an ``_HLGExprSequence`` if any operand is graph-backed.

        Returns ``None`` (no rewrite) when no operand is an ``HLGExpr``,
        ``HLGFinalizeCompute`` or plain ``dict``. Otherwise every operand is
        converted: dicts are wrapped in an ``HLGExpr``, and ordinary
        expressions are optimized, materialized and wrapped as well, which
        triggers a ``UserWarning`` about mixing collection types.
        """
        from dask.highlevelgraph import HighLevelGraph

        # Decide up front whether a rewrite is needed at all. Deciding per
        # operand based on what was seen so far would silently drop every
        # plain expression that precedes the first graph-backed operand.
        graph_backed = (HLGExpr, HLGFinalizeCompute, dict)
        if not any(isinstance(op, graph_backed) for op in self.operands):
            return None

        issue_warning = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            str(id(op)), op, dependencies=()
                        )
                    )
                )
            else:
                # A regular expression mixed with graph-backed operands has
                # to be materialized so it can join the merged graph.
                issue_warning = True
                opt = op.optimize()
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            opt._name, opt.__dask_graph__(), dependencies=()
                        )
                    )
                )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the dask."
                "compute calls if necessary.",
                UserWarning,
            )
        return _HLGExprSequence(*hlgs)
1133+
1134+
1135+
class FinalizeCompute(Expr):
    """Placeholder that simplifies to its operand's ``finalize_compute()``."""

    _parameters = ["expr"]

    def _simplify_down(self):
        """Delegate to the wrapped expression's ``finalize_compute``."""
        wrapped = self.expr
        return wrapped.finalize_compute()
1140+
1141+
1142+
def _convert_dask_keys(keys):
    """Convert a (possibly nested) list of keys into a task-spec ``List``.

    Nested lists are converted recursively; every other entry is wrapped
    in a ``TaskRef``. ``keys`` itself must be a list.
    """
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    converted = [
        _convert_dask_keys(entry) if isinstance(entry, list) else TaskRef(entry)
        for entry in keys
    ]
    return List(*converted)
1153+
1154+
1155+
class HLGFinalizeCompute(Expr):
    """Adds a final task applying the wrapped expression's postcompute.

    The single operand ``dsk`` is the wrapped expression (not a raw graph).
    """

    _parameters = ["dsk"]

    def __dask_annotations__(self):
        """Forward the annotations of the wrapped expression."""
        wrapped = self.dsk
        return wrapped.__dask_annotations__()

    def _simplify_down(self):
        """Drop this wrapper when the postcompute matches ``Delayed``'s."""
        from dask.delayed import Delayed

        # Skip finalization for Delayed
        matches_delayed = self.dsk.postcompute() == Delayed.__dask_postcompute__(
            self.dsk
        )
        return self.dsk if matches_delayed else self

    @property
    def _name(self):
        """Key of the finalize task, derived from the deterministic token."""
        return f"finalize-{self.deterministic_token}"

    def _layer(self) -> dict:
        """Return a copy of the wrapped layer plus the finalize task."""
        wrapped = self.operand("dsk")
        graph = wrapped._layer().copy()

        finalizer, finalizer_args = wrapped.postcompute()
        output_keys = wrapped.__dask_keys__()

        # The finalize task receives the wrapped output keys as task
        # references, followed by any extra postcompute arguments.
        finalize_task = Task(
            self._name, finalizer, _convert_dask_keys(output_keys), *finalizer_args
        )
        graph[finalize_task.key] = finalize_task
        return graph

    def __dask_keys__(self):
        """The only output key is the finalize task itself."""
        return [self._name]

dask/array/_array_expr/_collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __dask_postpersist__(self):
5151
return from_graph, (
5252
state._meta,
5353
state.chunks,
54+
# FIXME: This is using keys of the unoptimized graph
5455
list(flatten(state.__dask_keys__())),
5556
key_split(state._name),
5657
)

0 commit comments

Comments
 (0)