 from pytensor.tensor.variable import TensorVariable


+def _squeeze_left(x, stop_at_dim: int | None = None):
36+ """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
+    x_dims = x.type.broadcastable
+    squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
+    if stop_at_dim is not None:
+        squeeze_ndim = min(squeeze_ndim, stop_at_dim)
+    if squeeze_ndim == 0:
+        return x
+    return x.squeeze(axis=tuple(range(squeeze_ndim)))
+
+
 def _vectorize_node_perform(
     core_node: Apply,
     batch_bcast_patterns: Sequence[tuple[bool, ...]],
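A minimal sketch of what the new `_squeeze_left` helper does, assuming it is in scope (the tensors here are illustrative, not from the PR): leading broadcastable (statically length-1) dims are squeezed away, optionally capped by `stop_at_dim`.

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 1, 3))  # two leading broadcastable dims

# All leading broadcastable dims are squeezed away
assert _squeeze_left(x).type.shape == (3,)

# stop_at_dim caps how many leading dims may be removed
assert _squeeze_left(x, stop_at_dim=1).type.shape == (1, 3)

# A tensor whose first dim is not broadcastable is returned untouched
y = pt.tensor("y", shape=(2, 1, 3))
assert _squeeze_left(y) is y
```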
@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
 class Blockwise(COp):
     """Generalizes a core `Op` to work with batched dimensions.

-    TODO: Dispatch JAX (should be easy with the vectorize macro)
-    TODO: Dispatch Numba
     TODO: C implementation?
     TODO: Fuse Blockwise?
     """
@@ -202,21 +211,33 @@ def __init__(

         super().__init__(**kwargs)

-    def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
-        core_input_types = []
+    def _create_dummy_core_node(
+        self,
+        inputs: Sequence[TensorVariable],
+        propagate_unbatched_core_inputs: bool = False,
+        return_dummy_inputs: bool = False,
+    ) -> Apply | tuple[Apply, list[TensorVariable]]:
+        core_inputs = []
+        core_dummy_inputs = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
                 )
             # ndim_supp = 0 case
-            if not sig:
-                core_shape = ()
+            inp_ndim = inp.type.ndim
+            batch_ndim = inp_ndim - len(sig)
+            core_shape = inp.type.shape[batch_ndim:]
+            if propagate_unbatched_core_inputs and all(
+                inp.type.broadcastable[:batch_ndim]
+            ):
+                core_inputs.append(_squeeze_left(inp, batch_ndim))
             else:
-                core_shape = inp.type.shape[-len(sig) :]
-            core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape))
+                dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape)
+                core_inputs.append(dummy_inp)
+                core_dummy_inputs.append(dummy_inp)

-        core_node = self.core_op.make_node(*core_input_types)
+        core_node = self.core_op.make_node(*core_inputs)

         if len(core_node.outputs) != len(self.outputs_sig):
             raise ValueError(
@@ -230,6 +251,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
230251 f"Output { i } of { self .core_op } has wrong number of core dimensions for signature { self .signature } : { core_out .type .ndim } "
231252 )
232253
254+ if return_dummy_inputs :
255+ return core_node , core_dummy_inputs
256+
233257 return core_node
234258
235259 def make_node (self , * inputs ):
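A rough sketch of the two modes, calling the private helper directly just for illustration, and assuming `MatrixInverse` (from `pytensor.tensor.nlinalg`) exposes the gufunc signature `(m,m)->(m,m)`:

```python
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse

op = Blockwise(MatrixInverse())  # signature "(m,m)->(m,m)"

batched = pt.tensor("batched", shape=(5, 3, 3))
unbatched = pt.tensor("unbatched", shape=(1, 3, 3))  # only broadcastable batch dims

# Default: every input is replaced by a fresh dummy carrying just the core dims
node = op._create_dummy_core_node([batched])

# With the new flag, an input whose batch dims are all broadcastable is
# squeezed with _squeeze_left and fed to the core op directly, so no dummy
# is created for it and its static shape information is preserved
node, dummies = op._create_dummy_core_node(
    [unbatched],
    propagate_unbatched_core_inputs=True,
    return_dummy_inputs=True,
)
assert dummies == []  # the squeezed original input was propagated
```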
@@ -298,11 +322,17 @@ def infer_shape(

         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)

-        # Try to extract the core shapes from the core_op
-        core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
-        if core_op_infer_shape is not None:
-            dummy_core_node = self._create_dummy_core_node(node.inputs)
-            dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
+        def extract_core_shape_from_infer_shape():
+            # Try to extract the core shapes from the core_op
+            core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
+            if core_op_infer_shape is None:
+                return [[None] * out.ndim for out in node.outputs]
+
+            dummy_core_node, dummy_core_inputs = self._create_dummy_core_node(
+                node.inputs,
+                return_dummy_inputs=True,
+                propagate_unbatched_core_inputs=True,
+            )
             dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
             core_input_shapes = [
                 input_shape[batch_ndims:] for input_shape in input_shapes
@@ -311,6 +341,25 @@ def infer_shape(
                 dummy_fgraph, dummy_core_node, core_input_shapes
             )

+            # Set to None any core shape that depends on the dummy core inputs,
+            # since its value may not be constant across the batch dims of the Blockwise
+            if not dummy_core_inputs:
+                # All inputs are unbatched, so the core_shape can be used as is
+                return core_output_shapes
+            else:
+                set_dummy_core_inputs = set(dummy_core_inputs)
+                safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
+                for core_out_shape in safe_core_output_shapes:
+                    for d, core_out_dim in enumerate(core_out_shape):
+                        if set_dummy_core_inputs & set(
+                            explicit_graph_inputs([core_out_dim])
+                        ):
+                            core_out_shape[d] = None
+
+                return safe_core_output_shapes
+
+        safe_core_out_shape = None
+

         out_shapes = []
         for o, (output, sig) in enumerate(
             zip(node.outputs, self.outputs_sig, strict=True)
321370 if dim_name in core_dims :
322371 core_out_shape .append (core_dims [dim_name ])
323372 else :
324- if core_op_infer_shape is not None :
325- # If the input values are needed to compute the dimension length, we can't use the infer_shape
326- # of the core_node as the value is not constant across batch dims of the Blockwise
327- core_out_dim = core_output_shapes [o ][i ]
328- if not (
329- set (dummy_core_inputs )
330- & set (explicit_graph_inputs ([core_out_dim ]))
331- ):
332- core_out_shape .append (core_out_dim )
333- continue
334-
335- # Fallback shape requires evaluating the Blockwise Op
336- core_out_shape .append (Shape_i (batch_ndims + i )(output ))
+                    if safe_core_out_shape is None:
+                        # Extract the core shape from the core_op infer_shape on demand
+                        # For many Ops we never need to do this, because all info is in their signature
+                        safe_core_out_shape = extract_core_shape_from_infer_shape()
+                    if (core_out_dim := safe_core_out_shape[o][i]) is not None:
+                        core_out_shape.append(core_out_dim)
+                    else:
+                        # Fallback shape requires evaluating the Blockwise Op
+                        core_out_shape.append(Shape_i(batch_ndims + i)(output))
             out_shapes.append((*batch_shape, *core_out_shape))

         return out_shapes
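The practical effect, sketched with `matrix_inverse` (a `Blockwise` whose output dims are all pinned by its `(m,m)->(m,m)` signature): every output dim is found in `core_dims`, so the on-demand `extract_core_shape_from_infer_shape` path is never entered and no `Shape_i` fallback on the Blockwise output is needed.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import matrix_inverse

A = pt.tensor("A", shape=(None, 3, 3))
out = matrix_inverse(A)  # batched over the leading dim

# The compiled shape graph only inspects A's shape; the inverse itself
# is never computed
shape_fn = pytensor.function([A], out.shape)
pytensor.dprint(shape_fn)  # no Blockwise node should appear here
```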
@@ -448,7 +493,9 @@ def gufunc(
                 )
                 return core_func(*inputs)
         else:
-            core_node = self._create_dummy_core_node(node.inputs)  # type: ignore
+            core_node = self._create_dummy_core_node(
+                node.inputs, propagate_unbatched_core_inputs=True
+            )  # type: ignore
             gufunc = _vectorize_node_perform(
                 core_node,
                 batch_bcast_patterns=batch_bcast_patterns,