Commit 4787bc9

Fix rebase

1 parent 26bc88a commit 4787bc9
File tree: 7 files changed, +270 -69 lines

pytensor/link/numba/dispatch/subtensor.py
Lines changed: 1 addition & 31 deletions

@@ -244,37 +244,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     else:
         tensor_inputs = node.inputs[2:]

-    adv_idxs = [
-        {
-            "axis": i,
-            "dtype": idx.type.dtype,
-            "bcast": idx.type.broadcastable,
-            "ndim": idx.type.ndim,
-        }
-        for i, idx in enumerate(idxs)
-        if isinstance(idx.type, TensorType)
-    ]
-
     # Reconstruct indexing information from idx_list and tensor inputs
-    # basic_idxs = []
-    # adv_idxs = []
-    # input_idx = 0
-    #
-    # for i, entry in enumerate(op.idx_list):
-    #     if isinstance(entry, slice):
-    #         # Basic slice index
-    #         basic_idxs.append(entry)
-    #     elif isinstance(entry, Type):
-    #         # Advanced tensor index
-    #         if input_idx < len(tensor_inputs):
-    #             idx_input = tensor_inputs[input_idx]
-    #             adv_idxs.append({
-    #                 "axis": i,
-    #                 "dtype": idx_input.type.dtype,
-    #                 "bcast": idx_input.type.broadcastable,
-    #                 "ndim": idx_input.type.ndim,
-    #             })
-    #             input_idx += 1
     basic_idxs = []
     adv_idxs = []
     input_idx = 0
@@ -313,7 +283,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
         and len(adv_idxs) >= 1
         and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
         # Implementation does not support newaxis
-        and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
+        and not any(isinstance(idx.type, NoneTypeT) for idx in tensor_inputs)
     ):
         return vector_integer_advanced_indexing(op, node, **kwargs)

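For reference, a minimal NumPy sketch (editor's illustration, not part of the commit; array and index names are made up) of what the guard above admits and rejects before dispatching to vector_integer_advanced_indexing:

    import numpy as np

    x = np.arange(24).reshape(4, 3, 2)
    vec_idx = np.array([0, 2])

    # Admitted: plain vector integer advanced index, no booleans, no newaxis.
    print(x[vec_idx].shape)        # (2, 3, 2)

    # Rejected: np.newaxis (NoneTypeT) mixed into the index tuple...
    print(x[vec_idx, None].shape)  # (2, 1, 3, 2)

    # ...and boolean masks (any adv_idx with dtype "bool").
    mask = np.array([True, False, True, False])
    print(x[mask].shape)           # (2, 3, 2)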

pytensor/tensor/rewriting/subtensor.py
Lines changed: 174 additions & 21 deletions

@@ -33,6 +33,7 @@
     get_underlying_scalar_constant_value,
     register_infer_shape,
     switch,
+    tile,
 )
 from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.blockwise import _squeeze_left
@@ -73,7 +74,6 @@
     IncSubtensor,
     Subtensor,
     advanced_inc_subtensor1,
-    advanced_subtensor,
     advanced_subtensor1,
     as_index_constant,
     get_canonical_form_slice,
@@ -83,7 +83,7 @@
     inc_subtensor,
     indices_from_subtensor,
 )
-from pytensor.tensor.type import TensorType
+from pytensor.tensor.type import TensorType, integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorConstant, TensorVariable

@@ -256,6 +256,122 @@ def local_replace_AdvancedSubtensor(fgraph, node):
     return [new_res]


+def _compute_tiling_reps(val, target, allow_symbolic=False, target_shape=None):
+    """Compute tiling repetitions needed to broadcast val to match target shape.
+
+    Parameters
+    ----------
+    val : TensorVariable
+        The value to tile
+    target : TensorVariable
+        The target to match shape with (or None if using target_shape)
+    allow_symbolic : bool
+        If True, allow symbolic shapes (return reps with 1s, skip tiling).
+        If False, return None for symbolic shapes.
+    target_shape : tuple, optional
+        If provided, use this shape tuple instead of target.shape
+
+    Returns
+    -------
+    tuple or None
+        (needs_tiling, reps, has_symbolic_shapes) if compatible, None otherwise
+    """
+    try:
+        needs_tiling = False
+        reps = []
+        has_symbolic_shapes = False
+
+        def get_target_shape_i(i):
+            return target.shape[i] if i < len(target.shape) else None
+
+        if target_shape is None:
+            target_ndim = target.ndim
+        else:
+            target_ndim = len(target_shape)
+
+        for i in range(target_ndim):
+            try:
+                target_shape_i = get_target_shape_i(i)
+                val_shape_i = val.shape[i]
+            except (IndexError, AttributeError, TypeError):
+                return None
+
+            if target_shape_i is None:
+                # Symbolic shape in target - allow but skip tiling
+                reps.append(1)
+                continue
+
+            try:
+                target_shape_val = get_scalar_constant_value(
+                    target_shape_i, only_process_constants=True
+                )
+                val_shape_val = get_scalar_constant_value(
+                    val_shape_i, only_process_constants=True
+                )
+
+                if target_shape_val == val_shape_val:
+                    reps.append(1)
+                elif val_shape_val == 1:
+                    needs_tiling = True
+                    reps.append(target_shape_i)
+                else:
+                    return None
+
+            except NotScalarConstantError:
+                has_symbolic_shapes = True
+                if not allow_symbolic:
+                    return None
+                # For symbolic shapes, check dimension compatibility
+                if target_ndim == val.ndim:
+                    reps.append(1)
+                elif val.ndim == 0:
+                    reps.append(1)
+                elif val.ndim == 1 and target_ndim >= 1:
+                    reps.append(1)
+                elif val.ndim < target_ndim:
+                    return None
+                else:
+                    return None
+
+        return (needs_tiling, reps, has_symbolic_shapes)
+    except (TypeError, ValueError, AttributeError, IndexError):
+        return None
+
+
+def _validate_and_apply_tiling(val, reps):
+    """Validate that all reps are positive and apply tiling.
+
+    Parameters
+    ----------
+    val : TensorVariable
+        The value to tile
+    reps : list
+        Repetition counts for each dimension
+
+    Returns
+    -------
+    TensorVariable or None
+        Tiled value if valid, None otherwise
+    """
+    try:
+        for rep in reps:
+            if isinstance(rep, (int, np.integer)):
+                if rep <= 0:
+                    return None
+            else:
+                try:
+                    rep_val = get_scalar_constant_value(
+                        rep, only_process_constants=True
+                    )
+                    if rep_val <= 0:
+                        return None
+                except NotScalarConstantError:
+                    return None
+        return tile(val, reps)
+    except (TypeError, ValueError, AttributeError, IndexError):
+        return None
+
+
 @register_specialize
 @node_rewriter([AdvancedIncSubtensor])
 def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
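Aside (editor's illustration, not part of the commit): for fully constant shapes, the rep computation in _compute_tiling_reps reduces to the following pure-NumPy sketch; compute_reps is a hypothetical stand-in and np.tile plays the role of PyTensor's tile:

    import numpy as np

    def compute_reps(val_shape, target_shape):
        # Mirror of the constant-shape branch above: rep 1 where sizes already
        # match, rep target_shape[i] where val has length 1, else incompatible.
        reps = []
        for v, t in zip(val_shape, target_shape):
            if v == t:
                reps.append(1)
            elif v == 1:
                reps.append(t)
            else:
                return None  # caller bails out, as the rewrite does
        return reps

    val = np.ones((1, 3))
    reps = compute_reps(val.shape, (4, 3))  # -> [4, 1]
    assert np.tile(val, reps).shape == (4, 3)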
@@ -265,6 +381,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     """

     if type(node.op) is not AdvancedIncSubtensor:
+        # Don't apply to subclasses
         return

     if node.op.ignore_duplicates:
@@ -1321,7 +1438,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
     if isinstance(node.op, IncSubtensor):
         xi = Subtensor(node.op.idx_list)(x, *i)
     elif isinstance(node.op, AdvancedIncSubtensor):
-        xi = advanced_subtensor(x, *i)
+        # Use the same idx_list as the original operation to ensure correct shape
+        op = AdvancedSubtensor(node.op.idx_list)
+        xi = op.make_node(x, *i).outputs[0]
     elif isinstance(node.op, AdvancedIncSubtensor1):
         xi = advanced_subtensor1(x, *i)
     else:
@@ -1771,10 +1890,11 @@ def local_blockwise_inc_subtensor(fgraph, node):


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
-def bool_idx_to_nonzero(fgraph, node):
-    """Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
+def ravel_multidimensional_bool_idx(fgraph, node):
+    """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba

-    x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
+    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
+    x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """

     if isinstance(node.op, AdvancedSubtensor):
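The docstring equivalences can be checked directly in NumPy (editor's sketch, not part of the commit; `.set(y)` is PyTensor notation, emulated here with in-place assignment):

    import numpy as np

    x = np.arange(9, dtype=float).reshape(3, 3)
    mask = np.eye(3, dtype=bool)
    y = np.array([10.0, 20.0, 30.0])

    # Read case: x[mask] -> x.ravel()[mask.ravel()]
    assert np.array_equal(x[mask], x.ravel()[mask.ravel()])

    # Set case: update the raveled copy, then restore the original shape.
    flat = x.ravel().copy()
    flat[mask.ravel()] = y
    expected = x.copy()
    expected[mask] = y
    assert np.array_equal(flat.reshape(x.shape), expected)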
@@ -1787,26 +1907,53 @@ def bool_idx_to_nonzero(fgraph, node):
     # Reconstruct indices from idx_list and tensor inputs
     idxs = indices_from_subtensor(tensor_inputs, node.op.idx_list)

-    bool_pos = {
-        i
+    if any(
+        (
+            (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
+            or isinstance(idx.type, NoneTypeT)
+        )
+        for idx in idxs
+    ):
+        # Get out if there are any other advanced indexes or np.newaxis
+        return None
+
+    bool_idxs = [
+        (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
-    }
+    ]

-    if not bool_pos:
+    if len(bool_idxs) != 1:
+        # Get out if there are no or multiple boolean idxs
+        return None
+    [(bool_idx_pos, bool_idx)] = bool_idxs
+    bool_idx_ndim = bool_idx.type.ndim
+    if bool_idx.type.ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
         return None

-    new_idxs = []
-    for i, idx in enumerate(idxs):
-        if i in bool_pos:
-            new_idxs.extend(idx.nonzero())
-        else:
-            new_idxs.append(idx)
+    x_shape = x.shape
+    raveled_x = x.reshape(
+        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
+    )
+
+    raveled_bool_idx = bool_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[bool_idx_pos] = raveled_bool_idx

     if isinstance(node.op, AdvancedSubtensor):
-        new_out = node.op(x, *new_idxs)
+        new_out = raveled_x[tuple(new_idxs)]
     else:
-        new_out = node.op(x, y, *new_idxs)
+        sub = raveled_x[tuple(new_idxs)]
+        new_out = inc_subtensor(
+            sub,
+            y,
+            set_instead_of_inc=node.op.set_instead_of_inc,
+            ignore_duplicates=node.op.ignore_duplicates,
+            inplace=node.op.inplace,
+        )
+        new_out = new_out.reshape(x_shape)
+
     return [copy_stack_trace(node.outputs[0], new_out)]

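The reshape above collapses only the axes covered by the boolean index, which also handles masks sitting in the middle of the index tuple. A NumPy check of that case (editor's sketch, not part of the commit), with bool_idx_pos = 1 and bool_idx_ndim = 2:

    import numpy as np

    x = np.arange(2 * 3 * 3 * 2).reshape(2, 3, 3, 2)
    mask = np.eye(3, dtype=bool)  # indexes axes 1 and 2

    # (*x_shape[:1], -1, *x_shape[1 + 2:]) -> (2, 9, 2)
    raveled_x = x.reshape(2, -1, 2)
    assert np.array_equal(x[:, mask], raveled_x[:, mask.ravel()])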
@@ -1941,10 +2088,16 @@ def ravel_multidimensional_int_idx(fgraph, node):


 optdb["specialize"].register(
-    bool_idx_to_nonzero.__name__,
-    bool_idx_to_nonzero,
+    ravel_multidimensional_bool_idx.__name__,
+    ravel_multidimensional_bool_idx,
+    "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
+)
+
+optdb["specialize"].register(
+    ravel_multidimensional_int_idx.__name__,
+    ravel_multidimensional_int_idx,
     "numba",
-    "shape_unsafe",  # It can mask invalid mask sizes
     use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )

pytensor/tensor/subtensor.py
Lines changed: 3 additions & 9 deletions

@@ -922,11 +922,12 @@ def __init__(self, idx_list=None):

     def _normalize_idx_list_for_hash(self):
         """Normalize idx_list for hash and equality comparison."""
-        if self.idx_list is None:
+        idx_list = getattr(self, "idx_list", None)
+        if idx_list is None:
             return None

         msg = []
-        for entry in self.idx_list:
+        for entry in idx_list:
             if isinstance(entry, slice):
                 msg.append((entry.start, entry.stop, entry.step))
             else:
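The getattr guard presumably covers objects whose __init__ has not yet run (e.g. mid-unpickling). A standalone sketch of the normalization idea (editor's illustration, not the library's code; the non-slice branch is assumed, since the diff does not show it):

    def normalize_idx_list(idx_list):
        # slice objects are unhashable before Python 3.12, so reduce each one
        # to a hashable (start, stop, step) triple for use in __hash__/__eq__.
        if idx_list is None:
            return None
        msg = []
        for entry in idx_list:
            if isinstance(entry, slice):
                msg.append((entry.start, entry.stop, entry.step))
            else:
                msg.append(entry)  # assumed passthrough for non-slice entries
        return tuple(msg)

    print(normalize_idx_list([slice(1, None, 2), 0]))  # ((1, None, 2), 0)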
@@ -2812,13 +2813,6 @@ def make_node(self, x, *inputs):
         advanced_indices = []
         adv_group_axis = None
         last_adv_group_axis = None
-        if new_axes: #not defined?
-            expanded_x_shape_list = list(x.type.shape)
-            for new_axis in new_axes:
-                expanded_x_shape_list.insert(new_axis, 1)
-            expanded_x_shape = tuple(expanded_x_shape_list)
-        else:
-            expanded_x_shape = x.type.shape
         for i, (idx, dim_length) in enumerate(
             zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None))
         ):
