Commit b6c3199

Fix cache of default subtensor
The implementation was specializing on repeated node inputs: `unique_names` would return the same name for repeated inputs, but the cache key didn't account for this. We also don't want to compile different functions for different patterns of repeated inputs, as that doesn't translate into an obvious handle for the compiler to specialize on. If we wanted to inline constants, that may make more sense.
1 parent b5cddf9 commit b6c3199
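A rough sketch of the mismatch the message describes (hypothetical helper names, not PyTensor's actual API): the old per-variable naming made the generated source depend on which inputs were repeated, while the cache key did not record that; the new positional `idx_{i}` names depend only on how many index inputs the node has.

```python
# Hypothetical sketch of the mismatch, not PyTensor code.
def old_names(node_inputs):
    # same variable -> same name, so the generated source varied with repetition
    seen = {}
    return [seen.setdefault(id(v), f"in_{len(seen)}") for v in node_inputs]

def new_names(node_inputs):
    # one name per position; the identity of the inputs is irrelevant
    return [f"idx_{i}" for i in range(len(node_inputs))]

a, b = object(), object()
assert old_names([a, a]) != old_names([a, b])  # different source, same Op-only cache key
assert new_names([a, a]) == new_names([a, b])  # source now varies only with arity
```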

File tree

1 file changed (+14, -16)

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 14 additions & 16 deletions
```diff
@@ -18,7 +18,6 @@
     register_funcify_and_cache_key,
     register_funcify_default_op_cache_key,
 )
-from pytensor.link.utils import unique_name_generator
 from pytensor.tensor import TensorType
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
@@ -143,19 +142,14 @@ def subtensor_op_cache_key(op, **extra_fields):
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""
 
-    unique_names = unique_name_generator(
-        ["subtensor", "incsubtensor", "z"], suffix_sep="_"
-    )
-
-    def convert_indices(indices, entry):
-        if indices and isinstance(entry, Type):
-            rval = indices.pop(0)
-            return unique_names(rval)
+    def convert_indices(indice_names, entry):
+        if indice_names and isinstance(entry, Type):
+            return next(indice_names)
         elif isinstance(entry, slice):
             return (
-                f"slice({convert_indices(indices, entry.start)}, "
-                f"{convert_indices(indices, entry.stop)}, "
-                f"{convert_indices(indices, entry.step)})"
+                f"slice({convert_indices(indice_names, entry.start)}, "
+                f"{convert_indices(indice_names, entry.stop)}, "
+                f"{convert_indices(indice_names, entry.step)})"
             )
         elif isinstance(entry, type(None)):
             return "None"
```
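A self-contained sketch of what the rewritten `convert_indices` now does, consuming positional names from an iterator and rendering an `idx_list` into index-source strings. `FakeType` and the sample `idx_list` below are made up for illustration; in PyTensor the entries are `Type` instances.

```python
class FakeType:  # placeholder for the Type entries found in an Op's idx_list
    pass

def convert_indices(indice_names, entry):
    # Mirrors the new helper: Type entries consume the next positional name,
    # slices are rendered recursively, None is rendered literally.
    if indice_names and isinstance(entry, FakeType):
        return next(indice_names)
    elif isinstance(entry, slice):
        return (
            f"slice({convert_indices(indice_names, entry.start)}, "
            f"{convert_indices(indice_names, entry.stop)}, "
            f"{convert_indices(indice_names, entry.step)})"
        )
    elif entry is None:
        return "None"
    raise ValueError(f"Unsupported index entry: {entry}")  # sketch-only fallback

idx_list = [slice(FakeType(), FakeType(), None), FakeType()]
names = iter(["idx_0", "idx_1", "idx_2"])
print(tuple(convert_indices(names, entry) for entry in idx_list))
# ('slice(idx_0, idx_1, None)', 'idx_2')
```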
```diff
@@ -166,13 +160,15 @@ def convert_indices(indices, entry):
         op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
     )
     index_start_idx = 1 + int(set_or_inc)
-
-    input_names = [unique_names(v, force_unique=True) for v in node.inputs]
     op_indices = list(node.inputs[index_start_idx:])
     idx_list = getattr(op, "idx_list", None)
+    idx_names = [f"idx_{i}" for i in range(len(op_indices))]
 
+    input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]
+
+    idx_names_iterator = iter(idx_names)
     indices_creation_src = (
-        tuple(convert_indices(op_indices, idx) for idx in idx_list)
+        tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list)
         if idx_list
         else tuple(input_names[index_start_idx:])
     )
```
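The hunk above also rebuilds the generated kernel's parameter list so that it depends only on whether the Op writes into `x` (set/inc) and on the number of index inputs. A small illustrative sketch of the resulting names (the helper below is invented for illustration):

```python
def sketch_input_names(n_indices, set_or_inc):
    # Same construction as the diff: x, optionally y, then positional idx_* names.
    idx_names = [f"idx_{i}" for i in range(n_indices)]
    return ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]

print(sketch_input_names(2, set_or_inc=False))  # ['x', 'idx_0', 'idx_1']
print(sketch_input_names(2, set_or_inc=True))   # ['x', 'y', 'idx_0', 'idx_1']
```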
```diff
@@ -220,7 +216,9 @@ def {function_name}({", ".join(input_names)}):
         function_name=function_name,
         global_env=globals() | {"np": np},
     )
-    cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor")
+    cache_key = subtensor_op_cache_key(
+        op, func="numba_funcify_default_subtensor", version=1
+    )
     return numba_basic.numba_njit(func, boundscheck=True), cache_key
 
 
```
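Finally, the added `version=1` field changes the cache key for this funcify path, so kernels cached under the old (buggy) key are never looked up again. A minimal sketch of the idea, assuming a plain tuple-based key; the real `subtensor_op_cache_key` takes the Op itself plus these extra fields:

```python
def sketch_cache_key(op_name, **extra_fields):
    # Sketch only: the key is the op identity plus the extra keyword fields.
    return (op_name, tuple(sorted(extra_fields.items())))

old_key = sketch_cache_key("Subtensor", func="numba_funcify_default_subtensor")
new_key = sketch_cache_key("Subtensor", func="numba_funcify_default_subtensor", version=1)
assert old_key != new_key  # stale entries keyed without a version are never reused
```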