1818 register_funcify_and_cache_key ,
1919 register_funcify_default_op_cache_key ,
2020)
21- from pytensor .link .utils import unique_name_generator
2221from pytensor .tensor import TensorType
2322from pytensor .tensor .rewriting .subtensor import is_full_slice
2423from pytensor .tensor .subtensor import (
@@ -143,19 +142,14 @@ def subtensor_op_cache_key(op, **extra_fields):
143142def numba_funcify_default_subtensor (op , node , ** kwargs ):
144143 """Create a Python function that assembles and uses an index on an array."""
145144
146- unique_names = unique_name_generator (
147- ["subtensor" , "incsubtensor" , "z" ], suffix_sep = "_"
148- )
149-
150- def convert_indices (indices , entry ):
151- if indices and isinstance (entry , Type ):
152- rval = indices .pop (0 )
153- return unique_names (rval )
145+ def convert_indices (indice_names , entry ):
146+ if indice_names and isinstance (entry , Type ):
147+ return next (indice_names )
154148 elif isinstance (entry , slice ):
155149 return (
156- f"slice({ convert_indices (indices , entry .start )} , "
157- f"{ convert_indices (indices , entry .stop )} , "
158- f"{ convert_indices (indices , entry .step )} )"
150+ f"slice({ convert_indices (indice_names , entry .start )} , "
151+ f"{ convert_indices (indice_names , entry .stop )} , "
152+ f"{ convert_indices (indice_names , entry .step )} )"
159153 )
160154 elif isinstance (entry , type (None )):
161155 return "None"
@@ -166,13 +160,15 @@ def convert_indices(indices, entry):
166160 op , IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
167161 )
168162 index_start_idx = 1 + int (set_or_inc )
169-
170- input_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
171163 op_indices = list (node .inputs [index_start_idx :])
172164 idx_list = getattr (op , "idx_list" , None )
165+ idx_names = [f"idx_{ i } " for i in range (len (op_indices ))]
173166
167+ input_names = ["x" , "y" , * idx_names ] if set_or_inc else ["x" , * idx_names ]
168+
169+ idx_names_iterator = iter (idx_names )
174170 indices_creation_src = (
175- tuple (convert_indices (op_indices , idx ) for idx in idx_list )
171+ tuple (convert_indices (idx_names_iterator , idx ) for idx in idx_list )
176172 if idx_list
177173 else tuple (input_names [index_start_idx :])
178174 )
@@ -220,7 +216,9 @@ def {function_name}({", ".join(input_names)}):
220216 function_name = function_name ,
221217 global_env = globals () | {"np" : np },
222218 )
223- cache_key = subtensor_op_cache_key (op , func = "numba_funcify_default_subtensor" )
219+ cache_key = subtensor_op_cache_key (
220+ op , func = "numba_funcify_default_subtensor" , version = 1
221+ )
224222 return numba_basic .numba_njit (func , boundscheck = True ), cache_key
225223
226224
0 commit comments