3333 get_underlying_scalar_constant_value ,
3434 register_infer_shape ,
3535 switch ,
36+ tile ,
3637)
3738from pytensor .tensor .basic import constant as tensor_constant
3839from pytensor .tensor .blockwise import _squeeze_left
7374 IncSubtensor ,
7475 Subtensor ,
7576 advanced_inc_subtensor1 ,
76- advanced_subtensor ,
7777 advanced_subtensor1 ,
7878 as_index_constant ,
7979 get_canonical_form_slice ,
8383 inc_subtensor ,
8484 indices_from_subtensor ,
8585)
86- from pytensor .tensor .type import TensorType
86+ from pytensor .tensor .type import TensorType , integer_dtypes
8787from pytensor .tensor .type_other import NoneTypeT , SliceType
8888from pytensor .tensor .variable import TensorConstant , TensorVariable
8989
@@ -256,6 +256,122 @@ def local_replace_AdvancedSubtensor(fgraph, node):
256256 return [new_res ]
257257
258258
def _compute_tiling_reps(val, target, allow_symbolic=False, target_shape=None):
    """Compute tiling repetitions needed to broadcast ``val`` to match a target shape.

    Parameters
    ----------
    val : TensorVariable
        The value to tile.
    target : TensorVariable or None
        The target to match shape with. May be None when ``target_shape``
        is provided instead.
    allow_symbolic : bool
        If True, allow symbolic (non-constant) shapes: the corresponding
        rep is recorded as 1 and tiling is skipped for that dimension.
        If False, any symbolic shape makes the whole computation fail.
    target_shape : tuple, optional
        If provided, use this shape tuple instead of ``target.shape``.

    Returns
    -------
    tuple or None
        ``(needs_tiling, reps, has_symbolic_shapes)`` if the shapes are
        compatible, None otherwise.
    """
    try:
        needs_tiling = False
        reps = []
        has_symbolic_shapes = False

        def get_target_shape_i(i):
            # Honor target_shape when given: target itself may be None in
            # that case, so reading target.shape would be an error.
            if target_shape is not None:
                return target_shape[i] if i < len(target_shape) else None
            return target.shape[i] if i < len(target.shape) else None

        if target_shape is None:
            target_ndim = target.ndim
        else:
            target_ndim = len(target_shape)

        for i in range(target_ndim):
            try:
                target_shape_i = get_target_shape_i(i)
                val_shape_i = val.shape[i]
            except (IndexError, AttributeError, TypeError):
                return None

            if target_shape_i is None:
                # Symbolic shape in target - allow but skip tiling
                reps.append(1)
                continue

            try:
                target_shape_val = get_scalar_constant_value(
                    target_shape_i, only_process_constants=True
                )
                val_shape_val = get_scalar_constant_value(
                    val_shape_i, only_process_constants=True
                )

                if target_shape_val == val_shape_val:
                    # Dimensions already match; no repetition needed
                    reps.append(1)
                elif val_shape_val == 1:
                    # Broadcastable dim: repeat to the target's length
                    needs_tiling = True
                    reps.append(target_shape_i)
                else:
                    # Incompatible non-broadcastable mismatch
                    return None

            except NotScalarConstantError:
                has_symbolic_shapes = True
                if not allow_symbolic:
                    return None
                # For symbolic shapes, only accept dimension layouts we can
                # safely treat as "no tiling needed"; otherwise bail out.
                if target_ndim == val.ndim or val.ndim == 0 or val.ndim == 1:
                    reps.append(1)
                else:
                    return None

        return (needs_tiling, reps, has_symbolic_shapes)
    except (TypeError, ValueError, AttributeError, IndexError):
        return None
339+
340+
def _validate_and_apply_tiling(val, reps):
    """Apply ``tile(val, reps)`` once every repetition count checks out.

    Parameters
    ----------
    val : TensorVariable
        The value to tile.
    reps : list
        Repetition counts, one per dimension. Each entry must resolve to
        a positive constant integer.

    Returns
    -------
    TensorVariable or None
        The tiled value when all reps are valid, None otherwise.
    """
    try:
        for rep in reps:
            # Resolve each rep to a concrete integer, whether it is
            # already a Python/NumPy int or a constant graph scalar.
            if isinstance(rep, (int, np.integer)):
                rep_value = rep
            else:
                try:
                    rep_value = get_scalar_constant_value(
                        rep, only_process_constants=True
                    )
                except NotScalarConstantError:
                    return None
            if rep_value <= 0:
                return None
        return tile(val, reps)
    except (TypeError, ValueError, AttributeError, IndexError):
        return None
373+
374+
259375@register_specialize
260376@node_rewriter ([AdvancedIncSubtensor ])
261377def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 (fgraph , node ):
@@ -265,6 +381,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
265381 """
266382
267383 if type (node .op ) is not AdvancedIncSubtensor :
384+ # Don't apply to subclasses
268385 return
269386
270387 if node .op .ignore_duplicates :
@@ -1321,7 +1438,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
13211438 if isinstance (node .op , IncSubtensor ):
13221439 xi = Subtensor (node .op .idx_list )(x , * i )
13231440 elif isinstance (node .op , AdvancedIncSubtensor ):
1324- xi = advanced_subtensor (x , * i )
1441+ # Use the same idx_list as the original operation to ensure correct shape
1442+ op = AdvancedSubtensor (node .op .idx_list )
1443+ xi = op .make_node (x , * i ).outputs [0 ]
13251444 elif isinstance (node .op , AdvancedIncSubtensor1 ):
13261445 xi = advanced_subtensor1 (x , * i )
13271446 else :
@@ -1771,10 +1890,11 @@ def local_blockwise_inc_subtensor(fgraph, node):
17711890
17721891
17731892@node_rewriter (tracks = [AdvancedSubtensor , AdvancedIncSubtensor ])
1774- def bool_idx_to_nonzero (fgraph , node ):
1775- """Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
1893+ def ravel_multidimensional_bool_idx (fgraph , node ):
1894+ """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
17761895
1777- x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
1896+ x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
1897+ x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
17781898 """
17791899
17801900 if isinstance (node .op , AdvancedSubtensor ):
@@ -1787,26 +1907,53 @@ def bool_idx_to_nonzero(fgraph, node):
17871907 # Reconstruct indices from idx_list and tensor inputs
17881908 idxs = indices_from_subtensor (tensor_inputs , node .op .idx_list )
17891909
1790- bool_pos = {
1791- i
1910+ if any (
1911+ (
1912+ (isinstance (idx .type , TensorType ) and idx .type .dtype in integer_dtypes )
1913+ or isinstance (idx .type , NoneTypeT )
1914+ )
1915+ for idx in idxs
1916+ ):
1917+ # Get out if there are any other advanced indexes or np.newaxis
1918+ return None
1919+
1920+ bool_idxs = [
1921+ (i , idx )
17921922 for i , idx in enumerate (idxs )
17931923 if (isinstance (idx .type , TensorType ) and idx .dtype == "bool" )
1794- }
1924+ ]
17951925
1796- if not bool_pos :
1926+ if len (bool_idxs ) != 1 :
1927+ # Get out if there are no or multiple boolean idxs
1928+ return None
1929+ [(bool_idx_pos , bool_idx )] = bool_idxs
1930+ bool_idx_ndim = bool_idx .type .ndim
1931+ if bool_idx .type .ndim < 2 :
1932+ # No need to do anything if it's a vector or scalar, as it's already supported by Numba
17971933 return None
17981934
1799- new_idxs = []
1800- for i , idx in enumerate (idxs ):
1801- if i in bool_pos :
1802- new_idxs .extend (idx .nonzero ())
1803- else :
1804- new_idxs .append (idx )
1935+ x_shape = x .shape
1936+ raveled_x = x .reshape (
1937+ (* x_shape [:bool_idx_pos ], - 1 , * x_shape [bool_idx_pos + bool_idx_ndim :])
1938+ )
1939+
1940+ raveled_bool_idx = bool_idx .ravel ()
1941+ new_idxs = list (idxs )
1942+ new_idxs [bool_idx_pos ] = raveled_bool_idx
18051943
18061944 if isinstance (node .op , AdvancedSubtensor ):
1807- new_out = node . op ( x , * new_idxs )
1945+ new_out = raveled_x [ tuple ( new_idxs )]
18081946 else :
1809- new_out = node .op (x , y , * new_idxs )
1947+ sub = raveled_x [tuple (new_idxs )]
1948+ new_out = inc_subtensor (
1949+ sub ,
1950+ y ,
1951+ set_instead_of_inc = node .op .set_instead_of_inc ,
1952+ ignore_duplicates = node .op .ignore_duplicates ,
1953+ inplace = node .op .inplace ,
1954+ )
1955+ new_out = new_out .reshape (x_shape )
1956+
18101957 return [copy_stack_trace (node .outputs [0 ], new_out )]
18111958
18121959
@@ -1941,10 +2088,16 @@ def ravel_multidimensional_int_idx(fgraph, node):
19412088
19422089
# Register both raveling rewrites in the specialize database under the
# "numba" tag; they are not pulled in when only "specialize" is requested.
for _ravel_rewrite in (
    ravel_multidimensional_bool_idx,
    ravel_multidimensional_int_idx,
):
    optdb["specialize"].register(
        _ravel_rewrite.__name__,
        _ravel_rewrite,
        "numba",
        use_db_name_as_tag=False,  # Not included if only "specialize" is requested
    )
19502103
0 commit comments