2
2
3
3
import functools
4
4
import os
5
+ import warnings
5
6
import weakref
6
7
from collections import defaultdict
7
8
from collections .abc import Generator
11
12
12
13
import dask
13
14
from dask ._task_spec import Task
15
+ from dask .core import reverse_dict
14
16
from dask .tokenize import _tokenize_deterministic
15
17
from dask .typing import Key
16
- from dask .utils import funcname , import_required
18
+ from dask .utils import ensure_dict , funcname , import_required
17
19
18
20
if TYPE_CHECKING :
19
21
# TODO import from typing (requires Python >=3.10)
@@ -73,8 +75,13 @@ def _tune_down(self):
73
75
def _tune_up(self, parent):
    """Default hook: return ``None`` regardless of ``parent``."""
    return None
75
77
78
def finalize_compute(self):
    """Default implementation: hand back this very object unchanged."""
    return self
80
+
76
81
def _operands_for_repr(self):
    """Return one ``"<parameter>=<repr(operand)>"`` string per operand.

    Parameter names come from ``self._parameters`` and are paired
    positionally with ``self.operands``; pairing stops at the shorter of
    the two sequences.
    """
    rendered = []
    for param, op in zip(self._parameters, self.operands):
        rendered.append(f"{param}={repr(op)}")
    return rendered
78
85
79
86
def __str__ (self ):
80
87
s = ", " .join (self ._operands_for_repr ())
@@ -99,7 +106,7 @@ def _tree_repr_argument_construction(self, i, op, header):
99
106
return header
100
107
101
108
def _tree_repr_lines(self, indent=0, recursive=True):
    """Return the lines of the tree representation of this expression.

    Parameters
    ----------
    indent : int
        Number of spaces to prefix the line with.
    recursive : bool
        Accepted for signature compatibility; unused by this default
        implementation.

    Returns
    -------
    list[str]
        A single-element list holding the indented ``repr`` of ``self``.
    """
    # A list (not a bare string) is required: ``tree_repr`` joins the result
    # with ``os.linesep``, and joining a plain string would put every single
    # character on its own line.
    return [" " * indent + repr(self)]
103
110
104
111
def tree_repr(self):
    """Join the lines from ``_tree_repr_lines`` with ``os.linesep``."""
    lines = self._tree_repr_lines()
    return os.linesep.join(lines)
@@ -140,7 +147,7 @@ def __reduce__(self):
140
147
if dask .config .get ("dask-expr-no-serialize" , False ):
141
148
raise RuntimeError (f"Serializing a { type (self )} object" )
142
149
return Expr ._reconstruct , tuple (
143
- [type (self )] + self .operands + [ self .deterministic_token ]
150
+ [type (self ), * self .operands , self .deterministic_token ]
144
151
)
145
152
146
153
def _depth (self , cache = None ):
@@ -498,6 +505,9 @@ def _name(self) -> str:
498
505
def _meta(self):
    """Not provided at this level; always raises ``NotImplementedError``."""
    raise NotImplementedError()
500
507
508
def __dask_annotations__(self):
    """Default implementation: return a new, empty annotations dict."""
    return dict()
510
+
501
511
def __dask_graph__ (self ):
502
512
"""Traverse expression tree, collect layers"""
503
513
stack = [self ]
@@ -862,3 +872,314 @@ def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
862
872
return expr
863
873
864
874
raise ValueError (f"Stage { stage !r} not supported." )
875
+
876
+
877
class LLGExpr(Expr):
    """Expression wrapping a low-level graph held in the ``dsk`` operand."""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        """Return the keys of the ``dsk`` operand as a list."""
        graph = self.operand("dsk")
        return list(graph)

    def __dask_tokenize__(self):
        """Tokenize by object identity (the ``id`` of this instance)."""
        identity = id(self)
        return str(identity)

    def _layer(self) -> dict:
        """Return the ``dsk`` operand passed through ``ensure_dict``."""
        graph = self.operand("dsk")
        return ensure_dict(graph)
890
+
891
+
892
class HLGExpr(Expr):
    """Expression wrapping a collection that is backed by a ``HighLevelGraph``.

    Operands
    --------
    dsk
        The wrapped graph.
    low_level_optimizer
        Optional callable invoked as ``optimizer(dsk, keys)``; ``None``
        disables optimization.
    output_keys
        Optional output keys; derived from the graph when not provided.
    postcompute
        The collection's ``__dask_postcompute__`` bound method.
    _cached_optimized
        Slot for the result of ``_optimized_dsk``.
    """

    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        """Build an ``HLGExpr`` from a dask collection.

        Parameters
        ----------
        collection
            Object exposing ``dask`` or ``__dask_graph__``, plus
            ``__dask_keys__`` and ``__dask_postcompute__``.
        optimize_graph : bool
            When true, the collection's ``__dask_optimize__`` is used as the
            low-level optimizer. A ``PendingDeprecationWarning`` is issued
            (and no optimizer is used) if the collection does not define it.
        """
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__,
        )

    def finalize_compute(self):
        """Wrap this expression in an ``HLGFinalizeCompute``."""
        return HLGFinalizeCompute(self)

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        """Collect layer annotations as ``{annotation_type: {key: value}}``.

        Callable annotation values are evaluated once per key of the layer.
        """
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk()
        if isinstance(dsk, dict):
            # A plain dict has no layers; read them from the wrapped graph.
            dsk = self.dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        """Return the output keys, deriving them from the graph if unset.

        Derived keys are the graph keys that no other key depends on. They
        are stored on the instance so the derivation runs at most once.
        """
        # Attribute access (rather than ``self.operand``) so that keys stored
        # on the instance by a previous call are found and reused.
        if keys := self.output_keys:
            return keys
        dsk = self.operand("dsk")
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        dependents = reverse_dict(dependencies)
        keys = [d for d in dependents if not dependents[d] and d in dsk]
        self.output_keys = keys
        return keys

    def __dask_tokenize__(self):
        """Tokenize by object identity."""
        # There is currently no way to hash a HighLevelGraph fast and reliably.
        # It is important for dask-expr for this to not be duplicated so we'll
        # just use the ID.
        return str(id(self))

    def _optimized_dsk(self):
        """Return the graph after the low-level optimizer ran, caching it."""
        cached = self._cached_optimized
        # Compare against None: an empty optimized graph is falsy but is still
        # a valid cached result that must not be recomputed.
        if cached is not None:
            return cached
        keys = self.output_keys
        optimizer = self.low_level_optimizer
        if keys is None and optimizer is not None:
            keys = self.__dask_keys__()
        dsk = self.dsk
        if optimizer is not None:
            dsk = optimizer(dsk, keys)
        self._cached_optimized = dsk
        return dsk

    def _layer(self) -> dict:
        """Return the optimized graph passed through ``ensure_dict``."""
        dsk = self._optimized_dsk()
        return ensure_dict(dsk)
995
+
996
+
997
class _HLGExprSequence(Expr):
    """Sequence of graph-backed expressions handled as one unit.

    The operands are the member expressions; members that are ``HLGExpr``
    instances are grouped by their ``low_level_optimizer`` when the graph is
    built.
    """

    def __getitem__(self, other):
        """Index directly into the operands."""
        return self.operands[other]

    def _operands_for_repr(self):
        """Return ``name=...`` and ``dsk=...`` strings for the repr."""
        # NOTE(review): no ``_parameters`` naming ``name``/``dsk`` is visible
        # on this class - confirm these operand lookups resolve.
        name_part = f"name={self.operand('name')!r}"
        dsk_part = f"dsk={self.operand('dsk')!r}"
        return [name_part, dsk_part]

    def _tree_repr_lines(self, indent=0, recursive=True):
        """Return the repr operand strings; both arguments are ignored."""
        return self._operands_for_repr()

    def finalize_compute(self):
        """Wrap this sequence in an ``HLGFinalizeCompute``."""
        return HLGFinalizeCompute(self)

    def __dask_graph__(self):
        """Merge the members' graphs, optimizing once per optimizer group."""
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        from dask.highlevelgraph import HighLevelGraph

        def _optimizer_of(member):
            # Members that are not HLGExpr land in the ``None`` group.
            if isinstance(member, HLGExpr):
                return member.low_level_optimizer
            return None

        by_optimizer = toolz.groupby(_optimizer_of, self.operands)
        merged_per_group = []
        for optimizer, members in by_optimizer.items():
            # HLGExpr contributes its raw graph; anything else (e.g.
            # FinalizeCompute) contributes its layer.
            member_graphs = [
                member.dsk if isinstance(member, HLGExpr) else member._layer()
                for member in members
            ]
            dsk = HighLevelGraph.merge(*member_graphs)
            keys = [member.__dask_keys__() for member in members]
            if optimizer is not None:
                dsk = optimizer(dsk, keys)
            merged_per_group.append(dsk)

        combined = HighLevelGraph.merge(*merged_per_group)
        return ensure_dict(combined)

    _layer = __dask_graph__

    def __dask_annotations__(self):
        """Merge the annotations of all operands, grouped by annotation type."""
        merged = {}
        for member in self.operands:
            member_annotations = member.__dask_annotations__()
            for annot_type, per_key in member_annotations.items():
                if annot_type not in merged:
                    merged[annot_type] = {}
                merged[annot_type].update(per_key)
        return merged

    def __dask_keys__(self) -> list:
        """Return a list with one ``__dask_keys__()`` result per operand."""
        return [member.__dask_keys__() for member in self.operands]
1056
+
1057
+
1058
class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g. when
    being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        """Index directly into the operands."""
        return self.operands[other]

    def _layer(self) -> dict:
        """Merge the layers of all operands into a single dict."""
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        """Return a list with one ``__dask_keys__()`` result per operand."""
        return [op.__dask_keys__() for op in self.operands]

    def finalize_compute(self):
        """Return a new sequence with ``finalize_compute`` applied to each operand."""
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        """Merge the annotations of all operands, grouped by annotation type."""
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        """Number of operands."""
        return len(self.operands)

    def __iter__(self):
        """Iterate over the operands."""
        return iter(self.operands)

    def _simplify_down(self):
        """Lower to an ``_HLGExprSequence`` if any operand is graph-backed.

        Returns ``None`` when no operand is an ``HLGExpr``,
        ``HLGFinalizeCompute`` or ``dict``. Otherwise every operand is
        converted, in order: graph-backed expressions are kept, dicts are
        wrapped in an ``HLGExpr``, and any other operand is optimized and its
        graph wrapped in an ``HLGExpr`` (a ``UserWarning`` is emitted once in
        that case).
        """
        from dask.highlevelgraph import HighLevelGraph

        # Decide up front whether lowering is needed. Deciding per operand
        # based on what was seen so far would silently drop expressions that
        # precede the first graph-backed operand and shorten the sequence.
        graph_backed = (HLGExpr, HLGFinalizeCompute, dict)
        if not any(isinstance(op, graph_backed) for op in self.operands):
            return None

        issue_warning = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            str(id(op)), op, dependencies=()
                        )
                    )
                )
            else:
                issue_warning = True
                opt = op.optimize()
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            opt._name, opt.__dask_graph__(), dependencies=()
                        )
                    )
                )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the dask."
                "compute calls if necessary.",
                UserWarning,
            )
        return _HLGExprSequence(*hlgs)
1133
+
1134
+
1135
class FinalizeCompute(Expr):
    """Placeholder that simplifies to ``expr.finalize_compute()``."""

    _parameters = ["expr"]

    def _simplify_down(self):
        """Delegate to the wrapped expression's ``finalize_compute``."""
        wrapped = self.expr
        return wrapped.finalize_compute()
1140
+
1141
+
1142
def _convert_dask_keys(keys):
    """Convert a (possibly nested) list of keys into a ``List`` of ``TaskRef``.

    Nested lists are converted recursively; every non-list entry is wrapped
    in a ``TaskRef``. ``keys`` itself must be a ``list``.
    """
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    converted = [
        _convert_dask_keys(entry) if isinstance(entry, list) else TaskRef(entry)
        for entry in keys
    ]
    return List(*converted)
1153
+
1154
+
1155
class HLGFinalizeCompute(Expr):
    """Adds a single finalizing task on top of a wrapped graph expression.

    The ``dsk`` operand is the wrapped expression; it must provide
    ``postcompute()``, ``__dask_keys__()``, ``_layer()`` and
    ``__dask_annotations__()``.
    """

    _parameters = ["dsk"]

    def __dask_annotations__(self):
        """Forward to the wrapped expression's annotations."""
        wrapped = self.dsk
        return wrapped.__dask_annotations__()

    def _simplify_down(self):
        """Drop the finalization step when the postcompute matches Delayed's."""
        from dask.delayed import Delayed

        wrapped = self.dsk
        # Skip finalization for Delayed
        matches_delayed = wrapped.postcompute() == Delayed.__dask_postcompute__(
            wrapped
        )
        return wrapped if matches_delayed else self

    @property
    def _name(self):
        """Key of the finalizing task: ``finalize-<deterministic_token>``."""
        token = self.deterministic_token
        return f"finalize-{token}"

    def _layer(self) -> dict:
        """Return a copy of the wrapped layer plus the finalizing task."""
        expr = self.operand("dsk")
        graph = expr._layer().copy()

        func, extra_args = expr.postcompute()
        keys = expr.__dask_keys__()

        # The task applies the postcompute function to all output keys
        # followed by the extra postcompute arguments.
        finalize_task = Task(
            self._name, func, _convert_dask_keys(keys), *extra_args
        )
        graph[finalize_task.key] = finalize_task
        return graph

    def __dask_keys__(self):
        """The only output key is the finalizing task's name."""
        return [self._name]
0 commit comments