Skip to content

Commit e5c851c

Browse files
authored
Fix error on dependent DataFrame setitems (#2701)
1 parent c7d2bfb commit e5c851c

File tree

3 files changed

+116
-56
lines changed

3 files changed

+116
-56
lines changed

mars/optimization/logical/core.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from collections import defaultdict
1818
from dataclasses import dataclass
1919
from enum import Enum
20-
from typing import Dict, List, Tuple, Type
20+
from typing import Dict, List, Optional, Tuple, Type
2121

2222
from ...core import OperandType, ChunkType, EntityType, enter_mode
2323
from ...core.graph import EntityGraph
@@ -59,10 +59,12 @@ def append_record(self, record: OptimizationRecord):
5959
):
6060
self._optimized_chunk_to_records[record.new_chunk] = record
6161

62-
def get_optimization_result(self, original_chunk: ChunkType) -> ChunkType:
62+
def get_optimization_result(
63+
self, original_chunk: ChunkType, default: Optional[ChunkType] = None
64+
) -> ChunkType:
6365
chunk = original_chunk
6466
if chunk not in self._original_chunk_to_records:
65-
return
67+
return default
6668
while chunk in self._original_chunk_to_records:
6769
record = self._original_chunk_to_records[chunk]
6870
if record.record_type == OptimizationRecordType.replace:
@@ -72,10 +74,12 @@ def get_optimization_result(self, original_chunk: ChunkType) -> ChunkType:
7274
return None
7375
return chunk
7476

75-
def get_original_chunk(self, optimized_chunk: ChunkType) -> ChunkType:
77+
def get_original_chunk(
78+
self, optimized_chunk: ChunkType, default: Optional[ChunkType] = None
79+
) -> ChunkType:
7680
chunk = optimized_chunk
7781
if chunk not in self._optimized_chunk_to_records:
78-
return
82+
return default
7983
while chunk in self._optimized_chunk_to_records:
8084
record = self._optimized_chunk_to_records[chunk]
8185
if record.record_type == OptimizationRecordType.replace:
@@ -151,28 +155,25 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
151155
for succ in successors:
152156
self._graph.add_edge(new_node, succ)
153157

154-
@classmethod
155-
def _add_collapsable_predecessor(cls, node: EntityType, predecessor: EntityType):
156-
if predecessor not in cls._preds_to_remove:
157-
cls._preds_to_remove[predecessor] = {node}
158+
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
159+
pred_original = self._records.get_original_chunk(predecessor, predecessor)
160+
if predecessor not in self._preds_to_remove:
161+
self._preds_to_remove[pred_original] = {node}
158162
else:
159-
cls._preds_to_remove[predecessor].add(node)
163+
self._preds_to_remove[pred_original].add(node)
160164

161165
def _remove_collapsable_predecessors(self, node: EntityType):
162166
node = self._records.get_optimization_result(node) or node
163167
preds_opt_to_remove = []
164168
for pred in self._graph.predecessors(node):
165-
pred_original = self._records.get_original_chunk(pred)
166-
pred_original = pred_original if pred_original is not None else pred
167-
168-
pred_opt = self._records.get_optimization_result(pred)
169-
pred_opt = pred_opt if pred_opt is not None else pred
169+
pred_original = self._records.get_original_chunk(pred, pred)
170+
pred_opt = self._records.get_optimization_result(pred, pred)
170171

171172
if pred_opt in self._graph.results or pred_original in self._graph.results:
172173
continue
173174
affect_succ = self._preds_to_remove.get(pred_original) or []
174175
affect_succ_opt = [
175-
self._records.get_optimization_result(s) or s for s in affect_succ
176+
self._records.get_optimization_result(s, s) for s in affect_succ
176177
]
177178
if all(s in affect_succ_opt for s in self._graph.successors(pred)):
178179
preds_opt_to_remove.append((pred_original, pred_opt))

mars/optimization/logical/tileable/arithmetic_query.py

Lines changed: 87 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -105,13 +105,12 @@ def _is_select_dataframe_column(tileable) -> bool:
105105
and index_op.mask is None
106106
)
107107

108-
@classmethod
109-
def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
108+
def _extract_eval_expression(self, tileable) -> EvalExtractRecord:
110109
if is_scalar(tileable):
111110
if isinstance(tileable, (int, bool, str, bytes, np.integer, np.bool_)):
112111
return EvalExtractRecord(expr=repr(tileable))
113112
else:
114-
var_name = f"__eval_scalar_var{cls._next_var_id()}"
113+
var_name = f"__eval_scalar_var{self._next_var_id()}"
115114
var_dict = {var_name: tileable}
116115
return EvalExtractRecord(expr=f"@{var_name}", variables=var_dict)
117116

@@ -121,15 +120,15 @@ def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
121120
if tileable in _extract_result_cache:
122121
return _extract_result_cache[tileable]
123122

124-
if cls._is_select_dataframe_column(tileable):
125-
result = cls._extract_column_select(tileable)
123+
if self._is_select_dataframe_column(tileable):
124+
result = self._extract_column_select(tileable)
126125
elif isinstance(tileable.op, DataFrameUnaryUfunc):
127-
result = cls._extract_unary(tileable)
126+
result = self._extract_unary(tileable)
128127
elif isinstance(tileable.op, DataFrameBinopUfunc):
129128
if tileable.op.fill_value is not None or tileable.op.level is not None:
130129
result = EvalExtractRecord()
131130
else:
132-
result = cls._extract_binary(tileable)
131+
result = self._extract_binary(tileable)
133132
else:
134133
result = EvalExtractRecord()
135134

@@ -140,35 +139,33 @@ def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
140139
def _extract_column_select(cls, tileable) -> EvalExtractRecord:
141140
return EvalExtractRecord(tileable.inputs[0], f"`{tileable.op.col_names}`")
142141

143-
@classmethod
144-
def _extract_unary(cls, tileable) -> EvalExtractRecord:
142+
def _extract_unary(self, tileable) -> EvalExtractRecord:
145143
op = tileable.op
146144
func_name = getattr(op, "_func_name") or getattr(op, "_bin_func_name")
147145
if func_name not in _func_name_to_builder: # pragma: no cover
148146
return EvalExtractRecord()
149147

150-
in_tileable, expr, variables = cls._extract_eval_expression(op.inputs[0])
148+
in_tileable, expr, variables = self._extract_eval_expression(op.inputs[0])
151149
if in_tileable is None:
152150
return EvalExtractRecord()
153151

154-
cls._add_collapsable_predecessor(tileable, op.inputs[0])
152+
self._add_collapsable_predecessor(tileable, op.inputs[0])
155153
return EvalExtractRecord(
156154
in_tileable, _func_name_to_builder[func_name](expr), variables
157155
)
158156

159-
@classmethod
160-
def _extract_binary(cls, tileable) -> EvalExtractRecord:
157+
def _extract_binary(self, tileable) -> EvalExtractRecord:
161158
op = tileable.op
162159
func_name = getattr(op, "_func_name", None) or getattr(op, "_bit_func_name")
163160
if func_name not in _func_name_to_builder: # pragma: no cover
164161
return EvalExtractRecord()
165162

166-
lhs_tileable, lhs_expr, lhs_vars = cls._extract_eval_expression(op.lhs)
163+
lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs)
167164
if lhs_tileable is not None:
168-
cls._add_collapsable_predecessor(tileable, op.lhs)
169-
rhs_tileable, rhs_expr, rhs_vars = cls._extract_eval_expression(op.rhs)
165+
self._add_collapsable_predecessor(tileable, op.lhs)
166+
rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs)
170167
if rhs_tileable is not None:
171-
cls._add_collapsable_predecessor(tileable, op.rhs)
168+
self._add_collapsable_predecessor(tileable, op.rhs)
172169

173170
if lhs_expr is None or rhs_expr is None:
174171
return EvalExtractRecord()
@@ -190,6 +187,9 @@ def _extract_binary(cls, tileable) -> EvalExtractRecord:
190187
def apply(self, op: OperandType):
191188
node = op.outputs[0]
192189
in_tileable, expr, variables = self._extract_eval_expression(node)
190+
opt_in_tileable = self._records.get_optimization_result(
191+
in_tileable, in_tileable
192+
)
193193

194194
new_op = DataFrameEval(
195195
_key=node.op.key,
@@ -199,13 +199,13 @@ def apply(self, op: OperandType):
199199
parser="pandas",
200200
is_query=False,
201201
)
202-
new_node = new_op.new_tileable([in_tileable], **node.params).data
203-
new_node._key = node.key
204-
new_node._id = node.id
202+
new_node = new_op.new_tileable(
203+
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
204+
).data
205205

206206
self._remove_collapsable_predecessors(node)
207207
self._replace_node(node, new_node)
208-
self._graph.add_edge(in_tileable, new_node)
208+
self._graph.add_edge(opt_in_tileable, new_node)
209209

210210
self._records.append_record(
211211
OptimizationRecord(node, new_node, OptimizationRecordType.replace)
@@ -241,34 +241,40 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
241241
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
242242
raise NotImplementedError
243243

244-
def apply(self, op: DataFrameIndex):
245-
node = op.outputs[0]
246-
in_tileable = op.inputs[0]
247-
in_columnar_node = self._get_input_columnar_node(op)
248-
249-
new_op = self._build_new_eval_op(op)
250-
new_op._key = node.op.key
251-
252-
new_node = new_op.new_tileable([in_tileable], **node.params).data
253-
new_node._key = node.key
254-
new_node._id = node.id
255-
256-
self._add_collapsable_predecessor(node, in_columnar_node)
257-
self._remove_collapsable_predecessors(node)
244+
def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
245+
self._replace_node(old_node, new_node)
246+
for in_tileable in new_node.inputs:
247+
self._graph.add_edge(in_tileable, new_node)
258248

259-
self._replace_node(node, new_node)
260-
self._graph.add_edge(in_tileable, new_node)
249+
original_node = self._records.get_original_chunk(old_node, old_node)
261250
self._records.append_record(
262-
OptimizationRecord(node, new_node, OptimizationRecordType.replace)
251+
OptimizationRecord(original_node, new_node, OptimizationRecordType.replace)
263252
)
264253

265254
# check node if it's in result
266255
try:
267-
i = self._graph.results.index(node)
256+
i = self._graph.results.index(old_node)
268257
self._graph.results[i] = new_node
269258
except ValueError:
270259
pass
271260

261+
def apply(self, op: DataFrameIndex):
262+
node = op.outputs[0]
263+
in_tileable = op.inputs[0]
264+
in_columnar_node = self._get_input_columnar_node(op)
265+
opt_in_tileable = self._records.get_optimization_result(
266+
in_tileable, in_tileable
267+
)
268+
269+
new_op = self._build_new_eval_op(op)
270+
new_node = new_op.new_tileable(
271+
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
272+
).data
273+
274+
self._add_collapsable_predecessor(node, in_columnar_node)
275+
self._remove_collapsable_predecessors(node)
276+
self._update_op_node(node, new_node)
277+
272278

273279
@register_tileable_optimization_rule([DataFrameIndex])
274280
class DataFrameBoolEvalToQuery(_DataFrameEvalRewriteRule):
@@ -287,6 +293,7 @@ def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
287293
def _build_new_eval_op(self, op: OperandType):
288294
in_eval_op = self._get_optimized_eval_op(op)
289295
return DataFrameEval(
296+
_key=op.key,
290297
_output_types=get_output_types(op.outputs[0]),
291298
expr=in_eval_op.expr,
292299
variables=in_eval_op.variables,
@@ -308,10 +315,51 @@ def _get_input_columnar_node(self, op: DataFrameSetitem) -> ENTITY_TYPE:
308315
def _build_new_eval_op(self, op: DataFrameSetitem):
309316
in_eval_op = self._get_optimized_eval_op(op)
310317
return DataFrameEval(
318+
_key=op.key,
311319
_output_types=get_output_types(op.outputs[0]),
312320
expr=f"`{op.indexes}` = {in_eval_op.expr}",
313321
variables=in_eval_op.variables,
314322
parser="pandas",
315323
is_query=False,
316324
self_target=True,
317325
)
326+
327+
def apply(self, op: DataFrameIndex):
328+
super().apply(op)
329+
330+
node = op.outputs[0]
331+
opt_node = self._records.get_optimization_result(node, node)
332+
if not isinstance(opt_node.op, DataFrameEval): # pragma: no cover
333+
return
334+
335+
# when encountering consecutive SetItems, expressions can be
336+
# merged as a multiline expression
337+
pred_opt_node = opt_node.inputs[0]
338+
if (
339+
isinstance(pred_opt_node.op, DataFrameEval)
340+
and opt_node.op.parser == pred_opt_node.op.parser == "pandas"
341+
and not opt_node.op.is_query
342+
and not pred_opt_node.op.is_query
343+
and opt_node.op.self_target
344+
and pred_opt_node.op.self_target
345+
):
346+
new_expr = pred_opt_node.op.expr + "\n" + opt_node.op.expr
347+
new_variables = (pred_opt_node.op.variables or dict()).copy()
348+
new_variables.update(opt_node.op.variables or dict())
349+
350+
new_op = DataFrameEval(
351+
_key=op.key,
352+
_output_types=get_output_types(op.outputs[0]),
353+
expr=new_expr,
354+
variables=new_variables,
355+
parser="pandas",
356+
is_query=False,
357+
self_target=True,
358+
)
359+
new_node = new_op.new_tileable(
360+
pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
361+
).data
362+
363+
self._add_collapsable_predecessor(opt_node, pred_opt_node)
364+
self._remove_collapsable_predecessors(opt_node)
365+
self._update_op_node(opt_node, new_node)

mars/optimization/logical/tileable/tests/test_arithmetic_query.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,14 +153,25 @@ def test_eval_setitem_to_eval(setup):
153153
df2 = md.DataFrame(raw2, chunk_size=10)
154154
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
155155
df3["K"] = df3["A"] * (1 - df3["B"])
156+
df3["L"] = df3["K"] - df3["A"]
157+
df3["M"] = df3["K"] + df3["L"]
158+
156159
graph = TileableGraph([df3.data])
157160
next(TileableGraphBuilder(graph).build())
158161
records = optimize(graph)
159162
opt_df3 = records.get_optimization_result(df3.data)
160-
assert opt_df3.op.expr == "`K` = (`A`) * ((1) - (`B`))"
163+
assert opt_df3.op.expr == "\n".join(
164+
[
165+
"`K` = (`A`) * ((1) - (`B`))",
166+
"`L` = (`K`) - (`A`)",
167+
"`M` = (`K`) + (`L`)",
168+
]
169+
)
161170
assert len(graph) == 4
162171
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1
163172

164173
r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
165174
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
175+
r_df3["L"] = r_df3["K"] - r_df3["A"]
176+
r_df3["M"] = r_df3["K"] + r_df3["L"]
166177
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)

0 commit comments

Comments
 (0)