@@ -105,13 +105,12 @@ def _is_select_dataframe_column(tileable) -> bool:
105105 and index_op .mask is None
106106 )
107107
108- @classmethod
109- def _extract_eval_expression (cls , tileable ) -> EvalExtractRecord :
108+ def _extract_eval_expression (self , tileable ) -> EvalExtractRecord :
110109 if is_scalar (tileable ):
111110 if isinstance (tileable , (int , bool , str , bytes , np .integer , np .bool_ )):
112111 return EvalExtractRecord (expr = repr (tileable ))
113112 else :
114- var_name = f"__eval_scalar_var{ cls ._next_var_id ()} "
113+ var_name = f"__eval_scalar_var{ self ._next_var_id ()} "
115114 var_dict = {var_name : tileable }
116115 return EvalExtractRecord (expr = f"@{ var_name } " , variables = var_dict )
117116
@@ -121,15 +120,15 @@ def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
121120 if tileable in _extract_result_cache :
122121 return _extract_result_cache [tileable ]
123122
124- if cls ._is_select_dataframe_column (tileable ):
125- result = cls ._extract_column_select (tileable )
123+ if self ._is_select_dataframe_column (tileable ):
124+ result = self ._extract_column_select (tileable )
126125 elif isinstance (tileable .op , DataFrameUnaryUfunc ):
127- result = cls ._extract_unary (tileable )
126+ result = self ._extract_unary (tileable )
128127 elif isinstance (tileable .op , DataFrameBinopUfunc ):
129128 if tileable .op .fill_value is not None or tileable .op .level is not None :
130129 result = EvalExtractRecord ()
131130 else :
132- result = cls ._extract_binary (tileable )
131+ result = self ._extract_binary (tileable )
133132 else :
134133 result = EvalExtractRecord ()
135134
@@ -140,35 +139,33 @@ def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
def _extract_column_select(cls, tileable) -> EvalExtractRecord:
    """Build the extraction record for a column-selection tileable.

    The record's tileable is the first input of ``tileable`` and its
    expression is the selected column name(s) wrapped in backticks.
    """
    source = tileable.inputs[0]
    column_expr = f"`{tileable.op.col_names}`"
    return EvalExtractRecord(source, column_expr)
142141
def _extract_unary(self, tileable) -> EvalExtractRecord:
    """Build the eval expression for a tileable produced by a unary ufunc.

    The operand's function name selects a builder from
    ``_func_name_to_builder``; the expression of the operand's single input
    is extracted via ``_extract_eval_expression`` and wrapped by that builder.

    Parameters
    ----------
    tileable
        Tileable whose ``op`` is the unary operand to translate.

    Returns
    -------
    EvalExtractRecord
        An empty record when the function name has no registered builder or
        when no source tileable could be extracted from the input; otherwise
        a record holding the source tileable, the built expression and the
        variables collected from the input.
    """
    op = tileable.op
    # Default both lookups to None so an operand lacking these attributes
    # yields the empty "not extractable" record instead of raising
    # AttributeError (``_extract_binary`` already looks up ``_func_name``
    # this way).
    # NOTE(review): ``_extract_binary`` falls back to ``_bit_func_name``;
    # the ``_bin_func_name`` spelling here may be a typo -- confirm against
    # the operand classes before renaming.
    func_name = getattr(op, "_func_name", None) or getattr(
        op, "_bin_func_name", None
    )
    if func_name not in _func_name_to_builder:  # pragma: no cover
        return EvalExtractRecord()

    in_tileable, expr, variables = self._extract_eval_expression(op.inputs[0])
    if in_tileable is None:
        return EvalExtractRecord()

    # The input is folded into the expression, so register it as a
    # collapsable predecessor of this tileable.
    self._add_collapsable_predecessor(tileable, op.inputs[0])
    return EvalExtractRecord(
        in_tileable, _func_name_to_builder[func_name](expr), variables
    )
158156
159- @classmethod
160- def _extract_binary (cls , tileable ) -> EvalExtractRecord :
157+ def _extract_binary (self , tileable ) -> EvalExtractRecord :
161158 op = tileable .op
162159 func_name = getattr (op , "_func_name" , None ) or getattr (op , "_bit_func_name" )
163160 if func_name not in _func_name_to_builder : # pragma: no cover
164161 return EvalExtractRecord ()
165162
166- lhs_tileable , lhs_expr , lhs_vars = cls ._extract_eval_expression (op .lhs )
163+ lhs_tileable , lhs_expr , lhs_vars = self ._extract_eval_expression (op .lhs )
167164 if lhs_tileable is not None :
168- cls ._add_collapsable_predecessor (tileable , op .lhs )
169- rhs_tileable , rhs_expr , rhs_vars = cls ._extract_eval_expression (op .rhs )
165+ self ._add_collapsable_predecessor (tileable , op .lhs )
166+ rhs_tileable , rhs_expr , rhs_vars = self ._extract_eval_expression (op .rhs )
170167 if rhs_tileable is not None :
171- cls ._add_collapsable_predecessor (tileable , op .rhs )
168+ self ._add_collapsable_predecessor (tileable , op .rhs )
172169
173170 if lhs_expr is None or rhs_expr is None :
174171 return EvalExtractRecord ()
@@ -190,6 +187,9 @@ def _extract_binary(cls, tileable) -> EvalExtractRecord:
190187 def apply (self , op : OperandType ):
191188 node = op .outputs [0 ]
192189 in_tileable , expr , variables = self ._extract_eval_expression (node )
190+ opt_in_tileable = self ._records .get_optimization_result (
191+ in_tileable , in_tileable
192+ )
193193
194194 new_op = DataFrameEval (
195195 _key = node .op .key ,
@@ -199,13 +199,13 @@ def apply(self, op: OperandType):
199199 parser = "pandas" ,
200200 is_query = False ,
201201 )
202- new_node = new_op .new_tileable ([ in_tileable ], ** node . params ). data
203- new_node . _key = node .key
204- new_node . _id = node . id
202+ new_node = new_op .new_tileable (
203+ [ opt_in_tileable ], _key = node . key , _id = node . id , ** node .params
204+ ). data
205205
206206 self ._remove_collapsable_predecessors (node )
207207 self ._replace_node (node , new_node )
208- self ._graph .add_edge (in_tileable , new_node )
208+ self ._graph .add_edge (opt_in_tileable , new_node )
209209
210210 self ._records .append_record (
211211 OptimizationRecord (node , new_node , OptimizationRecordType .replace )
@@ -241,34 +241,40 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
    """Return the input node of ``op`` that ``apply`` registers as collapsable.

    This base implementation always raises ``NotImplementedError``; concrete
    rewrite rules are expected to override it.
    """
    raise NotImplementedError
243243
244- def apply (self , op : DataFrameIndex ):
245- node = op .outputs [0 ]
246- in_tileable = op .inputs [0 ]
247- in_columnar_node = self ._get_input_columnar_node (op )
248-
249- new_op = self ._build_new_eval_op (op )
250- new_op ._key = node .op .key
251-
252- new_node = new_op .new_tileable ([in_tileable ], ** node .params ).data
253- new_node ._key = node .key
254- new_node ._id = node .id
255-
256- self ._add_collapsable_predecessor (node , in_columnar_node )
257- self ._remove_collapsable_predecessors (node )
def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
    """Swap ``old_node`` for ``new_node`` and record the replacement.

    Steps, in order: replace the node via ``_replace_node``, add a graph edge
    from every input of ``new_node`` to it, append a ``replace`` optimization
    record, and update the graph's result list if ``old_node`` was in it.
    """
    self._replace_node(old_node, new_node)
    for source in new_node.inputs:
        self._graph.add_edge(source, new_node)

    # Look up the recorded original for ``old_node`` (falling back to
    # ``old_node`` itself) so the record is keyed on it -- presumably this
    # keeps chained rewrites pointing at the first node; verify against the
    # records implementation.
    origin = self._records.get_original_chunk(old_node, old_node)
    replacement = OptimizationRecord(
        origin, new_node, OptimizationRecordType.replace
    )
    self._records.append_record(replacement)

    # check node if it's in result
    try:
        position = self._graph.results.index(old_node)
    except ValueError:
        return
    self._graph.results[position] = new_node
271260
def apply(self, op: DataFrameIndex):
    """Rewrite the output of ``op`` as a node built from a new eval operand.

    The new node takes the optimization result recorded for the operand's
    first input (or that input itself when none is recorded) as its only
    input, and reuses the key, id and params of the node it replaces.
    """
    target = op.outputs[0]
    source = op.inputs[0]
    columnar_input = self._get_input_columnar_node(op)
    optimized_source = self._records.get_optimization_result(source, source)

    eval_op = self._build_new_eval_op(op)
    rewritten = eval_op.new_tileable(
        [optimized_source], _key=target.key, _id=target.id, **target.params
    ).data

    # The columnar input is folded into the eval expression, so mark it as
    # collapsable before removing collapsable predecessors of the target.
    self._add_collapsable_predecessor(target, columnar_input)
    self._remove_collapsable_predecessors(target)
    self._update_op_node(target, rewritten)
277+
272278
273279@register_tileable_optimization_rule ([DataFrameIndex ])
274280class DataFrameBoolEvalToQuery (_DataFrameEvalRewriteRule ):
@@ -287,6 +293,7 @@ def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
287293 def _build_new_eval_op (self , op : OperandType ):
288294 in_eval_op = self ._get_optimized_eval_op (op )
289295 return DataFrameEval (
296+ _key = op .key ,
290297 _output_types = get_output_types (op .outputs [0 ]),
291298 expr = in_eval_op .expr ,
292299 variables = in_eval_op .variables ,
@@ -308,10 +315,51 @@ def _get_input_columnar_node(self, op: DataFrameSetitem) -> ENTITY_TYPE:
def _build_new_eval_op(self, op: DataFrameSetitem):
    """Create the ``DataFrameEval`` operand that replaces a setitem operand.

    The expression assigns the optimized eval expression of the operand's
    input to the backtick-quoted ``op.indexes`` and is marked
    ``self_target`` so the assignment applies to the frame itself.
    """
    source_eval = self._get_optimized_eval_op(op)
    assignment = f"`{op.indexes}` = {source_eval.expr}"
    return DataFrameEval(
        _key=op.key,
        _output_types=get_output_types(op.outputs[0]),
        expr=assignment,
        variables=source_eval.variables,
        parser="pandas",
        is_query=False,
        self_target=True,
    )
326+
def apply(self, op: DataFrameSetitem):
    """Rewrite a setitem as an eval node and merge it with a preceding one.

    The parent ``apply`` performs the basic rewrite. Afterwards, if the
    resulting eval node's first input is itself a self-targeting,
    non-query eval node using the ``"pandas"`` parser, the two are merged
    into a single multi-line eval whose inputs are those of the
    predecessor.

    Parameters
    ----------
    op : DataFrameSetitem
        The setitem operand being optimized (annotated to match the other
        methods of this rule).
    """
    super().apply(op)

    node = op.outputs[0]
    opt_node = self._records.get_optimization_result(node, node)
    if not isinstance(opt_node.op, DataFrameEval):  # pragma: no cover
        return

    # when encountering consecutive SetItems, expressions can be
    # merged as a multiline expression
    pred_opt_node = opt_node.inputs[0]
    cur_op, pred_op = opt_node.op, pred_opt_node.op
    mergeable = (
        isinstance(pred_op, DataFrameEval)
        and cur_op.parser == pred_op.parser == "pandas"
        and not cur_op.is_query
        and not pred_op.is_query
        and cur_op.self_target
        and pred_op.self_target
    )
    if not mergeable:
        return

    # Predecessor's statement first, then this one; on a variable-name
    # clash the current operand's value wins.
    new_expr = pred_op.expr + "\n" + cur_op.expr
    new_variables = {**(pred_op.variables or {}), **(cur_op.variables or {})}

    new_op = DataFrameEval(
        _key=op.key,
        _output_types=get_output_types(op.outputs[0]),
        expr=new_expr,
        variables=new_variables,
        parser="pandas",
        is_query=False,
        self_target=True,
    )
    new_node = new_op.new_tileable(
        pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
    ).data

    self._add_collapsable_predecessor(opt_node, pred_opt_node)
    self._remove_collapsable_predecessors(opt_node)
    self._update_op_node(opt_node, new_node)
0 commit comments