@@ -860,12 +860,12 @@ def add_function(self, func, ndim, cols=None, func_name=None):
860860 self ._output_key_to_post_steps [step .output_key ] = step
861861 self ._update_col_dict (self ._output_key_to_post_cols , step .output_key , cols )
862862
863- @functools .lru_cache (100 )
864- def _compile_expr_function (self , py_src ):
863+ def _compile_expr_function (self , py_src : str , local_consts : dict ):
865864 from ... import tensor , dataframe
866865
867866 result_store = dict ()
868- global_vars = globals ()
867+ global_vars = globals ().copy ()
868+ global_vars .update (local_consts )
869869 global_vars .update (dict (mt = tensor , md = dataframe , array = np .array , nan = np .nan ))
870870 exec (
871871 py_src , global_vars , result_store
@@ -989,24 +989,24 @@ def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps:
989989 assert len (initial_inputs ) == 1
990990 input_key = initial_inputs [0 ].key
991991
992- func_str , _ = self ._generate_function_str (t .inputs [0 ])
992+ func_str , _ , local_consts = self ._generate_function_str (t .inputs [0 ])
993993 pre_funcs .append (
994994 ReductionPreStep (
995995 input_key ,
996996 agg_input_key ,
997997 None ,
998- self ._compile_expr_function (func_str ),
998+ self ._compile_expr_function (func_str , local_consts ),
999999 )
10001000 )
10011001 # collect function output after agg
1002- func_str , input_keys = self ._generate_function_str (func_ret )
1002+ func_str , input_keys , local_consts = self ._generate_function_str (func_ret )
10031003 post_funcs .append (
10041004 ReductionPostStep (
10051005 input_keys ,
10061006 func_ret .key ,
10071007 func_name ,
10081008 None ,
1009- self ._compile_expr_function (func_str ),
1009+ self ._compile_expr_function (func_str , local_consts ),
10101010 )
10111011 )
10121012 if len (_func_compile_cache ) > 100 : # pragma: no cover
@@ -1034,6 +1034,7 @@ def _generate_function_str(self, out_tileable):
10341034
10351035 input_key_to_var = OrderedDict ()
10361036 local_key_to_var = dict ()
1037+ local_consts_to_val = dict ()
10371038 ref_counts = dict ()
10381039 ref_visited = set ()
10391040 local_lines = []
@@ -1086,7 +1087,12 @@ def _interpret_var(v):
10861087 # get representation for variables
10871088 if hasattr (v , "key" ):
10881089 return keys_to_vars [v .key ]
1089- return v
1090+ elif isinstance (v , (int , bool , str , bytes , np .integer , np .bool_ )):
1091+ return repr (v )
1092+ else :
1093+ const_name = f"_const_{ len (local_consts_to_val )} "
1094+ local_consts_to_val [const_name ] = v
1095+ return const_name
10901096
10911097 func_name = func_name_raw = getattr (t .op , "_func_name" , None )
10921098 rfunc_name = getattr (t .op , "_rfunc_name" , func_name )
@@ -1187,6 +1193,7 @@ def _interpret_var(v):
11871193 f" { lines_str } \n "
11881194 f" return { local_key_to_var [out_tileable .key ]} " ,
11891195 list (input_key_to_var .keys ()),
1196+ local_consts_to_val ,
11901197 )
11911198
11921199 def compile (self ) -> ReductionSteps :
0 commit comments