1414from pytensor .link .numba .dispatch .cython_support import wrap_cython_function
1515from pytensor .link .utils import (
1616 get_name_for_object ,
17- unique_name_generator ,
1817)
1918from pytensor .scalar .basic import (
2019 Add ,
@@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
8180 scalar_func_numba = generate_fallback_impl (op , node , ** kwargs )
8281
8382 scalar_op_fn_name = get_name_for_object (scalar_func_numba )
84-
83+ prefix = "x" if scalar_func_name != "x" else "y"
84+ input_names = [f"{ prefix } { i } " for i in range (len (node .inputs ))]
85+ input_signature = ", " .join (input_names )
8586 global_env = {"scalar_func_numba" : scalar_func_numba }
8687
8788 if input_inner_dtypes is None and output_inner_dtype is None :
88- unique_names = unique_name_generator (
89- [scalar_op_fn_name , "scalar_func_numba" ], suffix_sep = "_"
90- )
91- input_names = ", " .join (unique_names (v , force_unique = True ) for v in node .inputs )
9289 if not has_pyx_skip_dispatch :
9390 scalar_op_src = f"""
94- def { scalar_op_fn_name } ({ input_names } ):
95- return scalar_func_numba({ input_names } )
91+ def { scalar_op_fn_name } ({ input_signature } ):
92+ return scalar_func_numba({ input_signature } )
9693 """
9794 else :
9895 scalar_op_src = f"""
99- def { scalar_op_fn_name } ({ input_names } ):
100- return scalar_func_numba({ input_names } , np.intc(1))
96+ def { scalar_op_fn_name } ({ input_signature } ):
97+ return scalar_func_numba({ input_signature } , np.intc(1))
10198 """
10299
103100 else :
@@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}):
108105 for i , i_dtype in enumerate (input_inner_dtypes )
109106 }
110107 global_env .update (input_tmp_dtype_names )
111-
112- unique_names = unique_name_generator (
113- [scalar_op_fn_name , "scalar_func_numba" , * global_env .keys ()],
114- suffix_sep = "_" ,
115- )
116-
117- input_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
118108 converted_call_args = ", " .join (
119109 f"direct_cast({ i_name } , { i_tmp_dtype_name } )"
120110 for i_name , i_tmp_dtype_name in zip (
@@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}):
123113 )
124114 if not has_pyx_skip_dispatch :
125115 scalar_op_src = f"""
126- def { scalar_op_fn_name } ({ ", " . join ( input_names ) } ):
116+ def { scalar_op_fn_name } ({ input_signature } ):
127117 return direct_cast(scalar_func_numba({ converted_call_args } ), output_dtype)
128118 """
129119 else :
130120 scalar_op_src = f"""
131- def { scalar_op_fn_name } ({ ", " . join ( input_names ) } ):
121+ def { scalar_op_fn_name } ({ input_signature } ):
132122 return direct_cast(scalar_func_numba({ converted_call_args } , np.intc(1)), output_dtype)
133123 """
134124
135125 scalar_op_fn = compile_numba_function_src (
136126 scalar_op_src ,
137127 scalar_op_fn_name ,
138- { ** globals (), ** global_env } ,
128+ globals () | global_env ,
139129 )
140130
141131 # Functions that call a function pointer can't be cached
@@ -157,8 +147,8 @@ def switch(condition, x, y):
157147
158148def binary_to_nary_func (inputs : list [Variable ], binary_op_name : str , binary_op : str ):
159149 """Create a Numba-compatible N-ary function from a binary function."""
160- unique_names = unique_name_generator ([ "binary_op_name" ], suffix_sep = "_" )
161- input_names = [unique_names ( v , force_unique = True ) for v in inputs ]
150+ var_prefix = "x" if binary_op_name != "x" else "y"
151+ input_names = [f" { var_prefix } { i } " for i in range ( len ( inputs )) ]
162152 input_signature = ", " .join (input_names )
163153 output_expr = binary_op .join (input_names )
164154
0 commit comments