1313# limitations under the License.
1414
1515import weakref
16- from typing import Optional , Tuple
16+ from typing import NamedTuple , Optional
1717
18+ import numpy as np
1819from pandas .api .types import is_scalar
1920
2021from .... import dataframe as md
2122from ....core import Tileable , get_output_types , ENTITY_TYPE
2223from ....dataframe .arithmetic .core import DataFrameUnaryUfunc , DataFrameBinopUfunc
2324from ....dataframe .base .eval import DataFrameEval
2425from ....dataframe .indexing .getitem import DataFrameIndex
26+ from ....dataframe .indexing .setitem import DataFrameSetitem
2527from ....typing import OperandType
2628from ....utils import implements
2729from ..core import OptimizationRecord , OptimizationRecordType
2830from ..tileable .core import register_tileable_optimization_rule
2931from .core import OptimizationRule
3032
3133
class EvalExtractRecord(NamedTuple):
    """Outcome of trying to express a tileable or scalar as an eval expression.

    Every field defaults to ``None``. Code in this module treats a record
    whose ``expr`` is ``None`` as "could not be converted".
    """

    # Tileable the expression is evaluated against, when there is one.
    tileable: Optional[Tileable] = None
    # Text of the expression.
    expr: Optional[str] = None
    # Values for names referenced from ``expr`` with the ``@name`` syntax.
    variables: Optional[dict] = None
38+
39+
3240def _get_binop_builder (op_str : str ):
3341 def builder (lhs : str , rhs : str ):
3442 return f"({ lhs } ) { op_str } ({ rhs } )"
@@ -60,9 +68,16 @@ def builder(lhs: str, rhs: str):
6068
6169@register_tileable_optimization_rule ([DataFrameUnaryUfunc , DataFrameBinopUfunc ])
6270class SeriesArithmeticToEval (OptimizationRule ):
# Class-level counter backing ``_next_var_id``.
_var_counter = 0

@classmethod
def _next_var_id(cls):
    """Advance the class-level counter by one and return the new value."""
    next_id = cls._var_counter + 1
    cls._var_counter = next_id
    return next_id
77+
@implements(OptimizationRule.match)
def match(self, op: OperandType) -> bool:
    """Report whether the first output of ``op`` yields an eval expression."""
    record = self._extract_eval_expression(op.outputs[0])
    return record.expr is not None
6782
6883 @staticmethod
@@ -91,14 +106,17 @@ def _is_select_dataframe_column(tileable) -> bool:
91106 )
92107
93108 @classmethod
94- def _extract_eval_expression (
95- cls , tileable
96- ) -> Tuple [Optional [Tileable ], Optional [str ]]:
109+ def _extract_eval_expression (cls , tileable ) -> EvalExtractRecord :
97110 if is_scalar (tileable ):
98- return None , repr (tileable )
111+ if isinstance (tileable , (int , bool , str , bytes , np .integer , np .bool_ )):
112+ return EvalExtractRecord (expr = repr (tileable ))
113+ else :
114+ var_name = f"__eval_scalar_var{ cls ._next_var_id ()} "
115+ var_dict = {var_name : tileable }
116+ return EvalExtractRecord (expr = f"@{ var_name } " , variables = var_dict )
99117
100118 if not isinstance (tileable , ENTITY_TYPE ): # pragma: no cover
101- return None , None
119+ return EvalExtractRecord ()
102120
103121 if tileable in _extract_result_cache :
104122 return _extract_result_cache [tileable ]
@@ -109,69 +127,75 @@ def _extract_eval_expression(
109127 result = cls ._extract_unary (tileable )
110128 elif isinstance (tileable .op , DataFrameBinopUfunc ):
111129 if tileable .op .fill_value is not None or tileable .op .level is not None :
112- result = None , None
130+ result = EvalExtractRecord ()
113131 else :
114132 result = cls ._extract_binary (tileable )
115133 else :
116- result = None , None
134+ result = EvalExtractRecord ()
117135
118136 _extract_result_cache [tileable ] = result
119137 return result
120138
@classmethod
def _extract_column_select(cls, tileable) -> EvalExtractRecord:
    """Build the record for a column selection.

    The record's tileable is the first input of ``tileable`` and its
    expression is ``tileable.op.col_names`` wrapped in backticks.
    """
    source = tileable.inputs[0]
    quoted_column = f"`{tileable.op.col_names}`"
    return EvalExtractRecord(tileable=source, expr=quoted_column)
126142
@classmethod
def _extract_unary(cls, tileable) -> EvalExtractRecord:
    """Extract an eval expression for a unary operation.

    Returns an empty ``EvalExtractRecord`` when the operand's function name
    has no registered builder or when its input cannot be expressed against
    a tileable. Otherwise returns the input's tileable, the expression
    produced by the matching builder, and the input's variables unchanged.
    """
    op = tileable.op
    # Every lookup carries a default so that an operand lacking one of these
    # attributes is reported as "not convertible" instead of raising
    # AttributeError. ``_bit_func_name`` is the fallback `_extract_binary`
    # uses; ``_bin_func_name`` is still consulted last because this method
    # historically looked it up.
    func_name = (
        getattr(op, "_func_name", None)
        or getattr(op, "_bit_func_name", None)
        or getattr(op, "_bin_func_name", None)
    )
    if func_name not in _func_name_to_builder:  # pragma: no cover
        return EvalExtractRecord()

    in_tileable, expr, variables = cls._extract_eval_expression(op.inputs[0])
    if in_tileable is None:
        # The input has no source tileable to evaluate against.
        return EvalExtractRecord()

    cls._add_collapsable_predecessor(tileable, op.inputs[0])
    return EvalExtractRecord(
        in_tileable, _func_name_to_builder[func_name](expr), variables
    )
140158
@classmethod
def _extract_binary(cls, tileable) -> EvalExtractRecord:
    """Extract an eval expression for a binary operation.

    Both operands are extracted recursively. An empty ``EvalExtractRecord``
    is returned when the function name has no registered builder, when
    either operand has no expression, when the two operands come from
    tileables with different keys, or when neither operand has a tileable.
    Otherwise the record holds the shared tileable, the combined expression
    and the union of both operands' variables.
    """
    op = tileable.op
    # Defaults on both lookups: an operand without either attribute is
    # treated as not convertible rather than raising AttributeError.
    func_name = getattr(op, "_func_name", None) or getattr(
        op, "_bit_func_name", None
    )
    if func_name not in _func_name_to_builder:  # pragma: no cover
        return EvalExtractRecord()

    lhs_tileable, lhs_expr, lhs_vars = cls._extract_eval_expression(op.lhs)
    if lhs_tileable is not None:
        cls._add_collapsable_predecessor(tileable, op.lhs)
    rhs_tileable, rhs_expr, rhs_vars = cls._extract_eval_expression(op.rhs)
    if rhs_tileable is not None:
        cls._add_collapsable_predecessor(tileable, op.rhs)

    if lhs_expr is None or rhs_expr is None:
        return EvalExtractRecord()
    if (
        lhs_tileable is not None
        and rhs_tileable is not None
        and lhs_tileable.key != rhs_tileable.key
    ):
        # Operands read from different sources; one expression cannot cover both.
        return EvalExtractRecord()

    in_tileable = lhs_tileable if lhs_tileable is not None else rhs_tileable
    if in_tileable is None:
        # Neither side is backed by a tileable, so there is nothing to
        # evaluate against. Previously this surfaced as StopIteration.
        return EvalExtractRecord()

    # Right-hand variables win on a name clash, as with dict.update().
    variables = {**(lhs_vars or {}), **(rhs_vars or {})}
    return EvalExtractRecord(
        in_tileable, _func_name_to_builder[func_name](lhs_expr, rhs_expr), variables
    )
165188
166189 @implements (OptimizationRule .apply )
167190 def apply (self , op : OperandType ):
168191 node = op .outputs [0 ]
169- in_tileable , expr = self ._extract_eval_expression (node )
192+ in_tileable , expr , variables = self ._extract_eval_expression (node )
170193
171194 new_op = DataFrameEval (
172195 _key = node .op .key ,
173196 _output_types = get_output_types (node ),
174197 expr = expr ,
198+ variables = variables or dict (),
175199 parser = "pandas" ,
176200 is_query = False ,
177201 )
@@ -195,39 +219,41 @@ def apply(self, op: OperandType):
195219 pass
196220
197221
198- @ register_tileable_optimization_rule ([ DataFrameIndex ])
199- class DataFrameBoolEvalToQuery ( OptimizationRule ) :
200- def match ( self , op : "DataFrameIndex" ) -> bool :
222+ class _DataFrameEvalRewriteRule ( OptimizationRule ):
def match(self, op: OperandType) -> bool:
    """Decide whether ``op`` can be rewritten on top of an existing eval op.

    True only when the op behind the columnar input is a ``DataFrameEval``
    that is not a query and whose first input has the same key as the
    first input of ``op``.
    """
    eval_op = self._get_optimized_eval_op(op)
    return bool(
        isinstance(eval_op, DataFrameEval)
        and not eval_op.is_query
        and eval_op.inputs[0].key == op.inputs[0].key
    )
212232
213- def apply (self , op : "DataFrameIndex" ):
def _build_new_eval_op(self, op: OperandType):
    """Hook for subclasses: create the eval op that replaces ``op``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
235+
def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
    """Return the op behind the columnar input of ``op``.

    If ``self._records`` holds an optimization result for that input, the
    op of the optimized node is returned; otherwise the input's own op.
    """
    columnar_node = self._get_input_columnar_node(op)
    replacement = self._records.get_optimization_result(columnar_node)
    if replacement is not None:
        return replacement.op
    return columnar_node.op
240+
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
    """Hook for subclasses: pick the columnar input node of ``op``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
243+
244+ def apply (self , op : DataFrameIndex ):
214245 node = op .outputs [0 ]
215246 in_tileable = op .inputs [0 ]
247+ in_columnar_node = self ._get_input_columnar_node (op )
248+
249+ new_op = self ._build_new_eval_op (op )
250+ new_op ._key = node .op .key
216251
217- optimized = self ._records .get_optimization_result (op .mask )
218- mask_op = optimized .op if optimized is not None else op .mask .op
219- new_op = DataFrameEval (
220- _key = node .op .key ,
221- _output_types = get_output_types (node ),
222- expr = mask_op .expr ,
223- parser = "pandas" ,
224- is_query = True ,
225- )
226252 new_node = new_op .new_tileable ([in_tileable ], ** node .params ).data
227253 new_node ._key = node .key
228254 new_node ._id = node .id
229255
230- self ._add_collapsable_predecessor (node , op . mask )
256+ self ._add_collapsable_predecessor (node , in_columnar_node )
231257 self ._remove_collapsable_predecessors (node )
232258
233259 self ._replace_node (node , new_node )
@@ -242,3 +268,50 @@ def apply(self, op: "DataFrameIndex"):
242268 self ._graph .results [i ] = new_node
243269 except ValueError :
244270 pass
271+
272+
@register_tileable_optimization_rule([DataFrameIndex])
class DataFrameBoolEvalToQuery(_DataFrameEvalRewriteRule):
    """Rule for ``DataFrameIndex`` ops whose mask is a bool ``Series``.

    The replacement is a ``DataFrameEval`` with ``is_query=True`` that
    reuses the expression and variables of the eval op behind the mask.
    """

    def match(self, op: DataFrameIndex) -> bool:
        """Accept only mask-based indexing (no ``col_names``) with a bool Series mask."""
        if op.col_names is not None:
            return False
        mask = op.mask
        if not isinstance(mask, md.Series):
            return False
        if mask.dtype != bool:
            return False
        return super().match(op)

    def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
        """The columnar input of an index op is its mask."""
        return op.mask

    def _build_new_eval_op(self, op: OperandType):
        """Create the query-mode eval op from the eval op behind the mask."""
        mask_eval_op = self._get_optimized_eval_op(op)
        eval_kwargs = dict(
            _output_types=get_output_types(op.outputs[0]),
            expr=mask_eval_op.expr,
            variables=mask_eval_op.variables,
            parser="pandas",
            is_query=True,
        )
        return DataFrameEval(**eval_kwargs)
296+
297+
@register_tileable_optimization_rule([DataFrameSetitem])
class DataFrameEvalSetItemToEval(_DataFrameEvalRewriteRule):
    """Rule for ``DataFrameSetitem`` ops that assign a ``Series`` to one column.

    The replacement is a ``DataFrameEval`` with ``self_target=True`` whose
    expression assigns the value's eval expression to the backtick-quoted
    column name.
    """

    def match(self, op: DataFrameSetitem) -> bool:
        """Accept only a single string column name and a ``Series`` value."""
        if not isinstance(op.indexes, str) or not isinstance(op.value, md.Series):
            return False
        # The column name is embedded between backticks in the generated
        # expression. A name that itself contains a backtick would end the
        # quoting early and produce a malformed expression, so leave such
        # assignments to the regular setitem path.
        if "`" in op.indexes:
            return False
        return super().match(op)

    def _get_input_columnar_node(self, op: DataFrameSetitem) -> ENTITY_TYPE:
        """The columnar input of a setitem op is the assigned value."""
        return op.value

    def _build_new_eval_op(self, op: DataFrameSetitem):
        """Create the assignment-style eval op from the eval op behind the value."""
        in_eval_op = self._get_optimized_eval_op(op)
        return DataFrameEval(
            _output_types=get_output_types(op.outputs[0]),
            expr=f"`{op.indexes}` = {in_eval_op.expr}",
            variables=in_eval_op.variables,
            parser="pandas",
            is_query=False,
            self_target=True,
        )
0 commit comments