1313# limitations under the License.
1414
1515import weakref
16- from typing import NamedTuple , Optional
16+ from abc import ABC
17+ from typing import NamedTuple , Optional , Type , Set
1718
1819import numpy as np
1920from pandas .api .types import is_scalar
2021
2122from .... import dataframe as md
22- from ....core import Tileable , get_output_types , ENTITY_TYPE
23+ from ....core import Tileable , get_output_types , ENTITY_TYPE , TileableGraph
24+ from ....core .graph import EntityGraph
2325from ....dataframe .arithmetic .core import DataFrameUnaryUfunc , DataFrameBinopUfunc
2426from ....dataframe .base .eval import DataFrameEval
2527from ....dataframe .indexing .getitem import DataFrameIndex
2628from ....dataframe .indexing .setitem import DataFrameSetitem
27- from ....typing import OperandType
29+ from ....typing import OperandType , EntityType
2830from ....utils import implements
29- from ..core import OptimizationRecord , OptimizationRecordType
31+ from ..core import (
32+ OptimizationRecord ,
33+ OptimizationRecordType ,
34+ OptimizationRecords ,
35+ Optimizer ,
36+ )
3037from ..tileable .core import register_operand_based_optimization_rule
3138from .core import OperandBasedOptimizationRule
3239
@@ -66,8 +73,70 @@ def builder(lhs: str, rhs: str):
6673_extract_result_cache = weakref .WeakKeyDictionary ()
6774
6875
class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
    """Base class for rules that replace a node with a single rewritten node.

    Subclasses record, via ``_mark_predecessor``, which predecessors a node
    has absorbed into its rewritten form. ``_replace_with_new_node`` then
    swaps the node (plus every predecessor that is no longer needed) for the
    new node and records the change in ``self._records``.
    """

    def __init__(
        self,
        graph: EntityGraph,
        records: OptimizationRecords,
        optimizer_cls: Type[Optimizer],
    ):
        super().__init__(graph, records, optimizer_cls)
        # Maps a predecessor's *original* entity to the set of nodes that
        # have marked it through ``_mark_predecessor``.
        self._marked_predecessors = {}

    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
        """Register ``node`` as a successor that has absorbed ``predecessor``.

        The mapping is always keyed by the predecessor's original entity, so
        marks made through an optimized predecessor and through its original
        accumulate in the same set.
        """
        pred_original = self._records.get_original_entity(predecessor, predecessor)
        # Both the lookup and the insertion use ``pred_original``; testing
        # membership with ``predecessor`` instead would overwrite the marks
        # already collected whenever the two entities differ.
        self._marked_predecessors.setdefault(pred_original, set()).add(node)

    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
        """Collect ``node`` and the predecessors that can be dropped with it.

        A predecessor is dropped only when it is not a graph result and every
        one of its successors in the graph has marked it. A ``delete`` record
        is appended for each dropped predecessor.

        Returns
        -------
        Set[EntityType]
            The (optimized) node itself plus the removable predecessors.
        """
        node = self._records.get_optimization_result(node) or node
        removed_nodes = {node}
        results_set = set(self._graph.results)
        removed_pairs = []
        for pred in self._graph.iter_predecessors(node):
            pred_original = self._records.get_original_entity(pred, pred)
            pred_opt = self._records.get_optimization_result(pred, pred)

            # Results of the graph must survive the rewrite.
            if pred_opt in results_set or pred_original in results_set:
                continue

            affect_succ = self._marked_predecessors.get(pred_original) or ()
            # Successors are compared in their optimized form because the
            # graph is iterated after earlier rewrites have been applied.
            affect_succ_opt = {
                self._records.get_optimization_result(s, s) for s in affect_succ
            }
            if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)):
                removed_pairs.append((pred_original, pred_opt))

        for pred_original, pred_opt in removed_pairs:
            removed_nodes.add(pred_opt)
            self._records.append_record(
                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
            )
        return removed_nodes

    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
        """Replace ``original_node`` and its removable predecessors by ``new_node``.

        Appends a ``replace`` record that maps the original entity of
        ``original_node`` to ``new_node``.
        """
        # Find all the nodes to remove
        nodes_to_remove = self._find_nodes_to_remove(original_node)

        # Build the replaced subgraph
        subgraph = TileableGraph()
        subgraph.add_node(new_node)

        # NOTE(review): ``None`` presumably tells ``_replace_subgraph`` to
        # leave the graph results untouched -- confirm against the base class.
        new_results = [new_node] if new_node in self._graph.results else None
        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
        self._records.append_record(
            OptimizationRecord(
                self._records.get_original_entity(original_node, original_node),
                new_node,
                OptimizationRecordType.replace,
            )
        )
137+
69138@register_operand_based_optimization_rule ([DataFrameUnaryUfunc , DataFrameBinopUfunc ])
70- class SeriesArithmeticToEval (OperandBasedOptimizationRule ):
139+ class SeriesArithmeticToEval (_EvalRewriteOptimizationRule ):
71140 _var_counter = 0
72141
73142 @classmethod
@@ -151,7 +220,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
151220 if in_tileable is None :
152221 return EvalExtractRecord ()
153222
154- self ._add_collapsable_predecessor (tileable , op .inputs [0 ])
223+ self ._mark_predecessor (tileable , op .inputs [0 ])
155224 return EvalExtractRecord (
156225 in_tileable , _func_name_to_builder [func_name ](expr ), variables
157226 )
@@ -164,10 +233,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:
164233
165234 lhs_tileable , lhs_expr , lhs_vars = self ._extract_eval_expression (op .lhs )
166235 if lhs_tileable is not None :
167- self ._add_collapsable_predecessor (tileable , op .lhs )
236+ self ._mark_predecessor (tileable , op .lhs )
168237 rhs_tileable , rhs_expr , rhs_vars = self ._extract_eval_expression (op .rhs )
169238 if rhs_tileable is not None :
170- self ._add_collapsable_predecessor (tileable , op .rhs )
239+ self ._mark_predecessor (tileable , op .rhs )
171240
172241 if lhs_expr is None or rhs_expr is None :
173242 return EvalExtractRecord ()
@@ -204,24 +273,10 @@ def apply_to_operand(self, op: OperandType):
204273 new_node = new_op .new_tileable (
205274 [opt_in_tileable ], _key = node .key , _id = node .id , ** node .params
206275 ).data
276+ self ._replace_with_new_node (node , new_node )
207277
208- self ._remove_collapsable_predecessors (node )
209- self ._replace_node (node , new_node )
210- self ._graph .add_edge (opt_in_tileable , new_node )
211278
212- self ._records .append_record (
213- OptimizationRecord (node , new_node , OptimizationRecordType .replace )
214- )
215-
216- # check node if it's in result
217- try :
218- i = self ._graph .results .index (node )
219- self ._graph .results [i ] = new_node
220- except ValueError :
221- pass
222-
223-
224- class _DataFrameEvalRewriteRule (OperandBasedOptimizationRule ):
279+ class _DataFrameEvalRewriteRule (_EvalRewriteOptimizationRule ):
225280 @implements (OperandBasedOptimizationRule .match_operand )
226281 def match_operand (self , op : OperandType ) -> bool :
227282 optimized_eval_op = self ._get_optimized_eval_op (op )
@@ -245,16 +300,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
245300 def _get_input_columnar_node (self , op : OperandType ) -> ENTITY_TYPE :
246301 raise NotImplementedError
247302
248- def _update_op_node (self , old_node : ENTITY_TYPE , new_node : ENTITY_TYPE ):
249- self ._replace_node (old_node , new_node )
250- for in_tileable in new_node .inputs :
251- self ._graph .add_edge (in_tileable , new_node )
252-
253- original_node = self ._records .get_original_entity (old_node , old_node )
254- self ._records .append_record (
255- OptimizationRecord (original_node , new_node , OptimizationRecordType .replace )
256- )
257-
258303 @implements (OperandBasedOptimizationRule .apply_to_operand )
259304 def apply_to_operand (self , op : DataFrameIndex ):
260305 node = op .outputs [0 ]
@@ -268,10 +313,8 @@ def apply_to_operand(self, op: DataFrameIndex):
268313 new_node = new_op .new_tileable (
269314 [opt_in_tileable ], _key = node .key , _id = node .id , ** node .params
270315 ).data
271-
272- self ._add_collapsable_predecessor (node , in_columnar_node )
273- self ._remove_collapsable_predecessors (node )
274- self ._update_op_node (node , new_node )
316+ self ._mark_predecessor (node , in_columnar_node )
317+ self ._replace_with_new_node (node , new_node )
275318
276319
277320@register_operand_based_optimization_rule ([DataFrameIndex ])
@@ -360,7 +403,5 @@ def apply_to_operand(self, op: DataFrameIndex):
360403 new_node = new_op .new_tileable (
361404 pred_opt_node .inputs , _key = node .key , _id = node .id , ** node .params
362405 ).data
363-
364- self ._add_collapsable_predecessor (opt_node , pred_opt_node )
365- self ._remove_collapsable_predecessors (opt_node )
366- self ._update_op_node (opt_node , new_node )
406+ self ._mark_predecessor (opt_node , pred_opt_node )
407+ self ._replace_with_new_node (opt_node , new_node )
0 commit comments