Skip to content

Commit 8505806

Browse files
committed
Refine ArithmeticToEval related rules
1 parent dd02ba9 commit 8505806

File tree

3 files changed

+84
-78
lines changed

3 files changed

+84
-78
lines changed

mars/optimization/logical/core.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
import functools
1515
import itertools
16-
import weakref
1716
from abc import ABC, abstractmethod
1817
from collections import defaultdict
1918
from dataclasses import dataclass
@@ -92,8 +91,6 @@ def get_original_entity(
9291

9392

9493
class OptimizationRule(ABC):
95-
_preds_to_remove = weakref.WeakKeyDictionary()
96-
9794
def __init__(
9895
self,
9996
graph: EntityGraph,
@@ -217,35 +214,6 @@ def _replace_subgraph(
217214
for result in new_results:
218215
self._graph.results[result_indices[result.key]] = result
219216

220-
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
221-
pred_original = self._records.get_original_entity(predecessor, predecessor)
222-
if predecessor not in self._preds_to_remove:
223-
self._preds_to_remove[pred_original] = {node}
224-
else:
225-
self._preds_to_remove[pred_original].add(node)
226-
227-
def _remove_collapsable_predecessors(self, node: EntityType):
228-
node = self._records.get_optimization_result(node) or node
229-
preds_opt_to_remove = []
230-
for pred in self._graph.predecessors(node):
231-
pred_original = self._records.get_original_entity(pred, pred)
232-
pred_opt = self._records.get_optimization_result(pred, pred)
233-
234-
if pred_opt in self._graph.results or pred_original in self._graph.results:
235-
continue
236-
affect_succ = self._preds_to_remove.get(pred_original) or []
237-
affect_succ_opt = [
238-
self._records.get_optimization_result(s, s) for s in affect_succ
239-
]
240-
if all(s in affect_succ_opt for s in self._graph.successors(pred)):
241-
preds_opt_to_remove.append((pred_original, pred_opt))
242-
243-
for pred_original, pred_opt in preds_opt_to_remove:
244-
self._graph.remove_node(pred_opt)
245-
self._records.append_record(
246-
OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
247-
)
248-
249217

250218
class OperandBasedOptimizationRule(OptimizationRule):
251219
"""

mars/optimization/logical/tests/test_core.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,7 @@ def test_replace_null_subgraph():
157157

158158
c1.inputs.clear()
159159
c2.inputs.clear()
160-
r.replace_subgraph(
161-
None,
162-
{key_to_node[op.key] for op in [s1, s2]}
163-
)
160+
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
164161
assert g1.results == expected_results
165162
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
166163
expected_edges = {

mars/optimization/logical/tileable/arithmetic_query.py

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,27 @@
1313
# limitations under the License.
1414

1515
import weakref
16-
from typing import NamedTuple, Optional
16+
from abc import ABC
17+
from typing import NamedTuple, Optional, Type, Set
1718

1819
import numpy as np
1920
from pandas.api.types import is_scalar
2021

2122
from .... import dataframe as md
22-
from ....core import Tileable, get_output_types, ENTITY_TYPE
23+
from ....core import Tileable, get_output_types, ENTITY_TYPE, TileableGraph
24+
from ....core.graph import EntityGraph
2325
from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc
2426
from ....dataframe.base.eval import DataFrameEval
2527
from ....dataframe.indexing.getitem import DataFrameIndex
2628
from ....dataframe.indexing.setitem import DataFrameSetitem
27-
from ....typing import OperandType
29+
from ....typing import OperandType, EntityType
2830
from ....utils import implements
29-
from ..core import OptimizationRecord, OptimizationRecordType
31+
from ..core import (
32+
OptimizationRecord,
33+
OptimizationRecordType,
34+
OptimizationRecords,
35+
Optimizer,
36+
)
3037
from ..tileable.core import register_operand_based_optimization_rule
3138
from .core import OperandBasedOptimizationRule
3239

@@ -66,8 +73,70 @@ def builder(lhs: str, rhs: str):
6673
_extract_result_cache = weakref.WeakKeyDictionary()
6774

6875

76+
class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
    """
    Shared base for rules that rewrite a node into a single new node and
    drop the predecessors that were folded into it.

    Subclasses call ``_mark_predecessor`` while analysing a node to record
    which predecessors the node absorbs, then ``_replace_with_new_node`` to
    swap the node (and every predecessor that is safe to drop) for the
    rewritten node.
    """

    def __init__(
        self,
        graph: EntityGraph,
        records: OptimizationRecords,
        optimizer_cls: Type[Optimizer],
    ):
        super().__init__(graph, records, optimizer_cls)
        # Maps the *original* form of a predecessor to the set of nodes that
        # marked it through ``_mark_predecessor``.
        self._marked_predecessors = {}

    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
        """
        Record that ``node`` absorbs ``predecessor``.

        Parameters
        ----------
        node
            The node whose rewrite makes ``predecessor`` redundant.
        predecessor
            The absorbed predecessor; it is stored under its original entity
            as reported by the optimization records.
        """
        pred_original = self._records.get_original_entity(predecessor, predecessor)
        # The lookup key must be the same one used for storing. Testing
        # membership with ``predecessor`` while storing under
        # ``pred_original`` would reset the set whenever the two differ and
        # silently lose earlier marks.
        self._marked_predecessors.setdefault(pred_original, set()).add(node)

    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
        """
        Collect ``node`` plus the predecessors that can be removed with it.

        A predecessor is selected when neither its original nor its optimized
        form is a graph result, and every one of its successors in the graph
        is the optimized form of a node that marked it. A ``delete`` record
        is appended for each selected predecessor.

        Parameters
        ----------
        node
            The node being replaced; its optimization result is used when the
            records hold one.

        Returns
        -------
        Set[EntityType]
            The resolved node and the optimized forms of the selected
            predecessors.
        """
        node = self._records.get_optimization_result(node) or node
        removed_nodes = {node}
        results_set = set(self._graph.results)

        # First decide, then record: no record is appended while the
        # predecessors are still being inspected.
        removed_pairs = []
        for pred in self._graph.iter_predecessors(node):
            pred_original = self._records.get_original_entity(pred, pred)
            pred_opt = self._records.get_optimization_result(pred, pred)

            if pred_opt in results_set or pred_original in results_set:
                # Results must stay in the graph.
                continue

            marking_nodes = self._marked_predecessors.get(pred_original) or ()
            marking_nodes_opt = {
                self._records.get_optimization_result(s, s) for s in marking_nodes
            }
            if all(
                succ in marking_nodes_opt
                for succ in self._graph.iter_successors(pred)
            ):
                removed_pairs.append((pred_original, pred_opt))

        for pred_original, pred_opt in removed_pairs:
            removed_nodes.add(pred_opt)
            self._records.append_record(
                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
            )
        return removed_nodes

    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
        """
        Replace ``original_node`` and its removable predecessors by ``new_node``.

        Parameters
        ----------
        original_node
            The node to replace.
        new_node
            The rewritten node that takes its place.
        """
        # Find all the nodes to remove
        nodes_to_remove = self._find_nodes_to_remove(original_node)

        # Build the replaced subgraph, which holds the new node only
        subgraph = TileableGraph()
        subgraph.add_node(new_node)

        # NOTE(review): this membership test presumably succeeds because
        # ``new_node`` compares equal to the result it replaces - confirm
        # against the entity equality rules.
        new_results = [new_node] if new_node in self._graph.results else None
        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
        self._records.append_record(
            OptimizationRecord(
                self._records.get_original_entity(original_node, original_node),
                new_node,
                OptimizationRecordType.replace,
            )
        )
136+
137+
69138
@register_operand_based_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc])
70-
class SeriesArithmeticToEval(OperandBasedOptimizationRule):
139+
class SeriesArithmeticToEval(_EvalRewriteOptimizationRule):
71140
_var_counter = 0
72141

73142
@classmethod
@@ -151,7 +220,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
151220
if in_tileable is None:
152221
return EvalExtractRecord()
153222

154-
self._add_collapsable_predecessor(tileable, op.inputs[0])
223+
self._mark_predecessor(tileable, op.inputs[0])
155224
return EvalExtractRecord(
156225
in_tileable, _func_name_to_builder[func_name](expr), variables
157226
)
@@ -164,10 +233,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:
164233

165234
lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs)
166235
if lhs_tileable is not None:
167-
self._add_collapsable_predecessor(tileable, op.lhs)
236+
self._mark_predecessor(tileable, op.lhs)
168237
rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs)
169238
if rhs_tileable is not None:
170-
self._add_collapsable_predecessor(tileable, op.rhs)
239+
self._mark_predecessor(tileable, op.rhs)
171240

172241
if lhs_expr is None or rhs_expr is None:
173242
return EvalExtractRecord()
@@ -204,24 +273,10 @@ def apply_to_operand(self, op: OperandType):
204273
new_node = new_op.new_tileable(
205274
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
206275
).data
276+
self._replace_with_new_node(node, new_node)
207277

208-
self._remove_collapsable_predecessors(node)
209-
self._replace_node(node, new_node)
210-
self._graph.add_edge(opt_in_tileable, new_node)
211278

212-
self._records.append_record(
213-
OptimizationRecord(node, new_node, OptimizationRecordType.replace)
214-
)
215-
216-
# check node if it's in result
217-
try:
218-
i = self._graph.results.index(node)
219-
self._graph.results[i] = new_node
220-
except ValueError:
221-
pass
222-
223-
224-
class _DataFrameEvalRewriteRule(OperandBasedOptimizationRule):
279+
class _DataFrameEvalRewriteRule(_EvalRewriteOptimizationRule):
225280
@implements(OperandBasedOptimizationRule.match_operand)
226281
def match_operand(self, op: OperandType) -> bool:
227282
optimized_eval_op = self._get_optimized_eval_op(op)
@@ -245,16 +300,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
245300
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
246301
raise NotImplementedError
247302

248-
def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
249-
self._replace_node(old_node, new_node)
250-
for in_tileable in new_node.inputs:
251-
self._graph.add_edge(in_tileable, new_node)
252-
253-
original_node = self._records.get_original_entity(old_node, old_node)
254-
self._records.append_record(
255-
OptimizationRecord(original_node, new_node, OptimizationRecordType.replace)
256-
)
257-
258303
@implements(OperandBasedOptimizationRule.apply_to_operand)
259304
def apply_to_operand(self, op: DataFrameIndex):
260305
node = op.outputs[0]
@@ -268,10 +313,8 @@ def apply_to_operand(self, op: DataFrameIndex):
268313
new_node = new_op.new_tileable(
269314
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
270315
).data
271-
272-
self._add_collapsable_predecessor(node, in_columnar_node)
273-
self._remove_collapsable_predecessors(node)
274-
self._update_op_node(node, new_node)
316+
self._mark_predecessor(node, in_columnar_node)
317+
self._replace_with_new_node(node, new_node)
275318

276319

277320
@register_operand_based_optimization_rule([DataFrameIndex])
@@ -360,7 +403,5 @@ def apply_to_operand(self, op: DataFrameIndex):
360403
new_node = new_op.new_tileable(
361404
pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
362405
).data
363-
364-
self._add_collapsable_predecessor(opt_node, pred_opt_node)
365-
self._remove_collapsable_predecessors(opt_node)
366-
self._update_op_node(opt_node, new_node)
406+
self._mark_predecessor(opt_node, pred_opt_node)
407+
self._replace_with_new_node(opt_node, new_node)

0 commit comments

Comments
 (0)