1010 register_funcify_and_cache_key ,
1111 register_funcify_default_op_cache_key ,
1212)
13- from pytensor .link .utils import unique_name_generator
1413from pytensor .tensor .basic import (
1514 Alloc ,
1615 AllocEmpty ,
2827
2928@register_funcify_default_op_cache_key (AllocEmpty )
3029def numba_funcify_AllocEmpty (op , node , ** kwargs ):
31- global_env = {
32- "np" : np ,
33- "dtype" : np .dtype (op .dtype ),
34- }
35-
36- unique_names = unique_name_generator (
37- ["np" , "dtype" , "allocempty" , "scalar_shape" ], suffix_sep = "_"
38- )
39- shape_var_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
30+ shape_var_names = [f"sh{ i } " for i in range (len (node .inputs ))]
4031 shape_var_item_names = [f"{ name } _item" for name in shape_var_names ]
4132 shapes_to_items_src = indent (
4233 "\n " .join (
@@ -56,21 +47,15 @@ def allocempty({", ".join(shape_var_names)}):
5647 """
5748
5849 alloc_fn = compile_numba_function_src (
59- alloc_def_src , "allocempty" , { ** globals (), ** global_env }
50+ alloc_def_src , "allocempty" , globals () | { "np" : np , "dtype" : np . dtype ( op . dtype ) }
6051 )
6152
6253 return numba_basic .numba_njit (alloc_fn )
6354
6455
6556@register_funcify_and_cache_key (Alloc )
6657def numba_funcify_Alloc (op , node , ** kwargs ):
67- global_env = {"np" : np }
68-
69- unique_names = unique_name_generator (
70- ["np" , "alloc" , "val_np" , "val" , "scalar_shape" , "res" ],
71- suffix_sep = "_" ,
72- )
73- shape_var_names = [unique_names (v , force_unique = True ) for v in node .inputs [1 :]]
58+ shape_var_names = [f"sh{ i } " for i in range (len (node .inputs ) - 1 )]
7459 shape_var_item_names = [f"{ name } _item" for name in shape_var_names ]
7560 shapes_to_items_src = indent (
7661 "\n " .join (
@@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}):
10287 alloc_fn = compile_numba_function_src (
10388 alloc_def_src ,
10489 "alloc" ,
105- { ** globals (), ** global_env },
90+ globals () | { "np" : np },
10691 )
10792
10893 cache_key = sha256 (
@@ -207,14 +192,7 @@ def eye(N, M, k):
207192@register_funcify_default_op_cache_key (MakeVector )
208193def numba_funcify_MakeVector (op , node , ** kwargs ):
209194 dtype = np .dtype (op .dtype )
210-
211- global_env = {"np" : np , "dtype" : dtype }
212-
213- unique_names = unique_name_generator (
214- ["np" ],
215- suffix_sep = "_" ,
216- )
217- input_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
195+ input_names = [f"x{ i } " for i in range (len (node .inputs ))]
218196
219197 def create_list_string (x ):
220198 args = ", " .join ([f"{ i } .item()" for i in x ] + (["" ] if len (x ) == 1 else []))
@@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}):
228206 makevector_fn = compile_numba_function_src (
229207 makevector_def_src ,
230208 "makevector" ,
231- { ** globals (), ** global_env },
209+ globals () | { "np" : np , "dtype" : dtype },
232210 )
233211
234212 return numba_basic .numba_njit (makevector_fn )
0 commit comments