Skip to content

Commit a446eaa

Browse files
authored
Eliminate redundant eval node in optimization (#2683)
1 parent 4fbf841 commit a446eaa

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

mars/optimization/logical/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,12 @@ def _remove_collapsable_predecessors(self, node: EntityType):
162162
node = self._records.get_optimization_result(node) or node
163163
preds_opt_to_remove = []
164164
for pred in self._graph.predecessors(node):
165-
pred_original = self._records.get_original_chunk(pred) or pred
166-
pred_opt = self._records.get_optimization_result(pred) or pred
165+
pred_original = self._records.get_original_chunk(pred)
166+
pred_original = pred_original if pred_original is not None else pred
167+
168+
pred_opt = self._records.get_optimization_result(pred)
169+
pred_opt = pred_opt if pred_opt is not None else pred
170+
167171
if pred_opt in self._graph.results or pred_original in self._graph.results:
168172
continue
169173
affect_succ = self._preds_to_remove.get(pred_original) or []

mars/optimization/logical/tileable/tests/test_arithmetic_query.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ def test_arithmetic_query(setup):
7474
r_df2, _r_col_a = fetch(execute(df2, df1["A"]))
7575
pd.testing.assert_series_equal(r_df2, -raw["A"] + raw["B"] * 5 + 3 * raw["C"])
7676

77+
df1 = md.DataFrame(raw, chunk_size=10)
78+
df2 = md.DataFrame(raw2, chunk_size=10)
79+
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
80+
df3["K"] = df4 = df3["A"] * (1 - df3["B"])
81+
graph = TileableGraph([df3.data])
82+
next(TileableGraphBuilder(graph).build())
83+
records = optimize(graph)
84+
opt_df4 = records.get_optimization_result(df4.data)
85+
assert opt_df4.op.expr == "(`A`) * ((1) - (`B`))"
86+
assert len(graph) == 5
87+
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1
88+
89+
r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
90+
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
91+
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)
92+
7793

7894
@enter_mode(build=True)
7995
def test_bool_eval_to_query(setup):

0 commit comments

Comments
 (0)