1313# limitations under the License.
1414
1515import weakref
16- from typing import NamedTuple , Optional
16+ from abc import ABC
17+ from typing import NamedTuple , Optional , Type , Set
1718
1819import numpy as np
1920from pandas .api .types import is_scalar
2021
2122from .... import dataframe as md
22- from ....core import Tileable , get_output_types , ENTITY_TYPE
23+ from ....core import Tileable , get_output_types , ENTITY_TYPE , TileableGraph
24+ from ....core .graph import EntityGraph
2325from ....dataframe .arithmetic .core import DataFrameUnaryUfunc , DataFrameBinopUfunc
2426from ....dataframe .base .eval import DataFrameEval
2527from ....dataframe .indexing .getitem import DataFrameIndex
2628from ....dataframe .indexing .setitem import DataFrameSetitem
27- from ....typing import OperandType
29+ from ....typing import OperandType , EntityType
2830from ....utils import implements
29- from ..core import OptimizationRecord , OptimizationRecordType
31+ from ..core import OptimizationRecord , OptimizationRecordType , OptimizationRecords
3032from ..tileable .core import register_operand_based_optimization_rule
3133from .core import OperandBasedOptimizationRule
3234
@@ -66,8 +68,70 @@ def builder(lhs: str, rhs: str):
6668_extract_result_cache = weakref .WeakKeyDictionary ()
6769
6870
class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
    """Shared base for rules that replace a node with a newly built node.

    Subclasses call :meth:`_mark_predecessor` while analysing a node to note
    which predecessors the node was derived from, and later call
    :meth:`_replace_with_new_node` to swap the node (plus any marked
    predecessors that are no longer needed) for the new node.
    """

    def __init__(
        self,
        graph: EntityGraph,
        records: OptimizationRecords,
        optimizer_cls: Type["Optimizer"],
    ):
        super().__init__(graph, records, optimizer_cls)
        # Maps the *original* predecessor entity to the set of nodes that
        # were marked against it through ``_mark_predecessor``.
        self._marked_predecessors = dict()

    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
        """Record that ``node`` was marked against ``predecessor``.

        The entry is keyed by the original entity of ``predecessor`` as
        resolved through the optimization records (falling back to
        ``predecessor`` itself).

        Parameters
        ----------
        node
            The node being marked against ``predecessor``.
        predecessor
            The predecessor of ``node`` to mark.
        """
        pred_original = self._records.get_original_entity(predecessor, predecessor)
        # The lookup must use the same key as the store. Testing membership
        # with ``predecessor`` while storing under ``pred_original`` would
        # overwrite (and lose) previously marked nodes whenever the two differ.
        self._marked_predecessors.setdefault(pred_original, set()).add(node)

    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
        """Collect ``node`` and the predecessors that can be dropped with it.

        ``node`` is first resolved to its optimization result when one is
        recorded. A predecessor is selected for removal when neither it nor
        its original entity is a graph result, and every one of its
        successors in the graph is (the optimized form of) a node previously
        marked against it. A ``delete`` record is appended for each selected
        predecessor.

        Parameters
        ----------
        node
            The node that is about to be replaced.

        Returns
        -------
        Set[EntityType]
            The resolved ``node`` together with the optimized form of every
            selected predecessor.
        """
        node = self._records.get_optimization_result(node) or node
        removed_nodes = {node}
        results_set = set(self._graph.results)
        removed_pairs = []
        for pred in self._graph.iter_predecessors(node):
            pred_original = self._records.get_original_entity(pred, pred)
            pred_opt = self._records.get_optimization_result(pred, pred)

            # Results of the graph must survive the rewrite.
            if pred_opt in results_set or pred_original in results_set:
                continue

            affect_succ = self._marked_predecessors.get(pred_original) or []
            affect_succ_opt = {
                self._records.get_optimization_result(s, s) for s in affect_succ
            }
            # Only drop the predecessor when no unmarked successor still
            # depends on it.
            if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)):
                removed_pairs.append((pred_original, pred_opt))

        # Records are appended only after the scan above has finished, so the
        # record lookups performed during the scan see an unchanged state.
        for pred_original, pred_opt in removed_pairs:
            removed_nodes.add(pred_opt)
            self._records.append_record(
                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
            )
        return removed_nodes

    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
        """Replace ``original_node`` (and removable predecessors) by ``new_node``.

        Builds a single-node ``TileableGraph`` holding ``new_node``, hands it
        to ``_replace_subgraph`` together with the nodes returned by
        :meth:`_find_nodes_to_remove`, and appends a ``replace`` record that
        maps the original entity of ``original_node`` to ``new_node``.

        Parameters
        ----------
        original_node
            The node currently in the graph that is to be replaced.
        new_node
            The node taking its place.
        """
        # Find all the nodes to remove
        nodes_to_remove = self._find_nodes_to_remove(original_node)

        # Build the replaced subgraph
        subgraph = TileableGraph()
        subgraph.add_node(new_node)

        # ``new_results`` is only supplied when ``new_node`` compares equal to
        # an entry of the current graph results; otherwise ``None`` is passed.
        new_results = [new_node] if new_node in self._graph.results else None
        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
        self._records.append_record(
            OptimizationRecord(
                self._records.get_original_entity(original_node, original_node),
                new_node,
                OptimizationRecordType.replace,
            )
        )
131+
132+
69133@register_operand_based_optimization_rule ([DataFrameUnaryUfunc , DataFrameBinopUfunc ])
70- class SeriesArithmeticToEval (OperandBasedOptimizationRule ):
134+ class SeriesArithmeticToEval (_EvalRewriteOptimizationRule ):
71135 _var_counter = 0
72136
73137 @classmethod
@@ -151,7 +215,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
151215 if in_tileable is None :
152216 return EvalExtractRecord ()
153217
154- self ._add_collapsable_predecessor (tileable , op .inputs [0 ])
218+ self ._mark_predecessor (tileable , op .inputs [0 ])
155219 return EvalExtractRecord (
156220 in_tileable , _func_name_to_builder [func_name ](expr ), variables
157221 )
@@ -164,10 +228,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:
164228
165229 lhs_tileable , lhs_expr , lhs_vars = self ._extract_eval_expression (op .lhs )
166230 if lhs_tileable is not None :
167- self ._add_collapsable_predecessor (tileable , op .lhs )
231+ self ._mark_predecessor (tileable , op .lhs )
168232 rhs_tileable , rhs_expr , rhs_vars = self ._extract_eval_expression (op .rhs )
169233 if rhs_tileable is not None :
170- self ._add_collapsable_predecessor (tileable , op .rhs )
234+ self ._mark_predecessor (tileable , op .rhs )
171235
172236 if lhs_expr is None or rhs_expr is None :
173237 return EvalExtractRecord ()
@@ -204,24 +268,10 @@ def apply_to_operand(self, op: OperandType):
204268 new_node = new_op .new_tileable (
205269 [opt_in_tileable ], _key = node .key , _id = node .id , ** node .params
206270 ).data
271+ self ._replace_with_new_node (node , new_node )
207272
208- self ._remove_collapsable_predecessors (node )
209- self ._replace_node (node , new_node )
210- self ._graph .add_edge (opt_in_tileable , new_node )
211273
212- self ._records .append_record (
213- OptimizationRecord (node , new_node , OptimizationRecordType .replace )
214- )
215-
216- # check node if it's in result
217- try :
218- i = self ._graph .results .index (node )
219- self ._graph .results [i ] = new_node
220- except ValueError :
221- pass
222-
223-
224- class _DataFrameEvalRewriteRule (OperandBasedOptimizationRule ):
274+ class _DataFrameEvalRewriteRule (_EvalRewriteOptimizationRule ):
225275 @implements (OperandBasedOptimizationRule .match_operand )
226276 def match_operand (self , op : OperandType ) -> bool :
227277 optimized_eval_op = self ._get_optimized_eval_op (op )
@@ -245,16 +295,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
245295 def _get_input_columnar_node (self , op : OperandType ) -> ENTITY_TYPE :
246296 raise NotImplementedError
247297
248- def _update_op_node (self , old_node : ENTITY_TYPE , new_node : ENTITY_TYPE ):
249- self ._replace_node (old_node , new_node )
250- for in_tileable in new_node .inputs :
251- self ._graph .add_edge (in_tileable , new_node )
252-
253- original_node = self ._records .get_original_entity (old_node , old_node )
254- self ._records .append_record (
255- OptimizationRecord (original_node , new_node , OptimizationRecordType .replace )
256- )
257-
258298 @implements (OperandBasedOptimizationRule .apply_to_operand )
259299 def apply_to_operand (self , op : DataFrameIndex ):
260300 node = op .outputs [0 ]
@@ -268,10 +308,8 @@ def apply_to_operand(self, op: DataFrameIndex):
268308 new_node = new_op .new_tileable (
269309 [opt_in_tileable ], _key = node .key , _id = node .id , ** node .params
270310 ).data
271-
272- self ._add_collapsable_predecessor (node , in_columnar_node )
273- self ._remove_collapsable_predecessors (node )
274- self ._update_op_node (node , new_node )
311+ self ._mark_predecessor (node , in_columnar_node )
312+ self ._replace_with_new_node (node , new_node )
275313
276314
277315@register_operand_based_optimization_rule ([DataFrameIndex ])
@@ -360,7 +398,5 @@ def apply_to_operand(self, op: DataFrameIndex):
360398 new_node = new_op .new_tileable (
361399 pred_opt_node .inputs , _key = node .key , _id = node .id , ** node .params
362400 ).data
363-
364- self ._add_collapsable_predecessor (opt_node , pred_opt_node )
365- self ._remove_collapsable_predecessors (opt_node )
366- self ._update_op_node (opt_node , new_node )
401+ self ._mark_predecessor (opt_node , pred_opt_node )
402+ self ._replace_with_new_node (opt_node , new_node )
0 commit comments