 from collections.abc import Callable, Sequence
-from typing import Any, cast
+from typing import Any, Literal, cast, overload

 import numpy as np
 from numpy import broadcast_shapes, empty
 from pytensor.tensor.variable import TensorVariable


+def _squeeze_left(x, stop_at_dim: int | None = None):
+    """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
+    x_dims = x.type.broadcastable
+    squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
+    if stop_at_dim is not None:
+        squeeze_ndim = min(squeeze_ndim, stop_at_dim)
+    if squeeze_ndim == 0:
+        return x
+    return x.squeeze(axis=tuple(range(squeeze_ndim)))
+
+
 def _vectorize_node_perform(
     core_node: Apply,
     batch_bcast_patterns: Sequence[tuple[bool, ...]],
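Note: the new `_squeeze_left` helper removes only leading broadcastable (length-1) dims, leaving everything from the first non-broadcastable dim onward intact. A minimal sketch of its behavior, with illustrative shapes:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(1, 1, 5, 3))      # two broadcastable leading dims
    _squeeze_left(x).type.shape                 # (5, 3): all leading 1s dropped
    _squeeze_left(x, stop_at_dim=1).type.shape  # (1, 5, 3): at most one dim squeezed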
@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
 class Blockwise(COp):
     """Generalizes a core `Op` to work with batched dimensions.

-    TODO: Dispatch JAX (should be easy with the vectorize macro)
-    TODO: Dispatch Numba
     TODO: C implementation?
     TODO: Fuse Blockwise?
     """
@@ -202,21 +211,52 @@ def __init__(

         super().__init__(**kwargs)

-    def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
-        core_input_types = []
+    @overload
+    def _create_dummy_core_node(
+        self,
+        inputs: Sequence[TensorVariable],
+        *,
+        propagate_unbatched_core_inputs: bool = False,
+        return_dummy_inputs: Literal[False] = ...,
+    ) -> Apply: ...
+
+    @overload
+    def _create_dummy_core_node(
+        self,
+        inputs: Sequence[TensorVariable],
+        *,
+        propagate_unbatched_core_inputs: bool = False,
+        return_dummy_inputs: Literal[True] = ...,
+    ) -> tuple[Apply, list[TensorVariable]]: ...
+
+    def _create_dummy_core_node(
+        self,
+        inputs: Sequence[TensorVariable],
+        *,
+        propagate_unbatched_core_inputs: bool = False,
+        return_dummy_inputs: bool = False,
+    ) -> Apply | tuple[Apply, list[TensorVariable]]:
+        core_inputs = []
+        core_dummy_inputs = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
                 )
             # ndim_supp = 0 case
-            if not sig:
-                core_shape = ()
+            inp_ndim = inp.type.ndim
+            batch_ndim = inp_ndim - len(sig)
+            core_shape = inp.type.shape[batch_ndim:]
+            if propagate_unbatched_core_inputs and all(
+                inp.type.broadcastable[:batch_ndim]
+            ):
+                core_inputs.append(_squeeze_left(inp, batch_ndim))
             else:
-                core_shape = inp.type.shape[-len(sig) :]
-            core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape))
+                dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape)
+                core_inputs.append(dummy_inp)
+                core_dummy_inputs.append(dummy_inp)

-        core_node = self.core_op.make_node(*core_input_types)
+        core_node = self.core_op.make_node(*core_inputs)

         if len(core_node.outputs) != len(self.outputs_sig):
             raise ValueError(
@@ -230,6 +270,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
                     f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
                 )

+        if return_dummy_inputs:
+            return core_node, core_dummy_inputs
+
         return core_node

     def make_node(self, *inputs):
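Note: the paired `@overload` declarations let static type checkers narrow the return type from the literal value of `return_dummy_inputs`: callers passing `True` see `tuple[Apply, list[TensorVariable]]`, all other callers keep the plain `Apply`. A self-contained sketch of the same pattern, with hypothetical names:

    from typing import Literal, overload

    @overload
    def load(*, with_meta: Literal[False] = ...) -> bytes: ...
    @overload
    def load(*, with_meta: Literal[True] = ...) -> tuple[bytes, dict]: ...
    def load(*, with_meta: bool = False) -> bytes | tuple[bytes, dict]:
        data, meta = b"payload", {"len": 7}
        return (data, meta) if with_meta else data

    data = load()                # inferred: bytes
    pair = load(with_meta=True)  # inferred: tuple[bytes, dict]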
@@ -298,11 +341,17 @@ def infer_shape(

         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)

-        # Try to extract the core shapes from the core_op
-        core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
-        if core_op_infer_shape is not None:
-            dummy_core_node = self._create_dummy_core_node(node.inputs)
-            dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
+        def extract_core_shape_from_infer_shape():
+            # Try to extract the core shapes from the core_op
+            core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
+            if core_op_infer_shape is None:
+                return [[None] * out.ndim for out in node.outputs]
+
+            dummy_core_node, dummy_core_inputs = self._create_dummy_core_node(
+                node.inputs,
+                return_dummy_inputs=True,
+                propagate_unbatched_core_inputs=True,
+            )
             dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
             core_input_shapes = [
                 input_shape[batch_ndims:] for input_shape in input_shapes
@@ -311,6 +360,25 @@ def infer_shape(
                 dummy_fgraph, dummy_core_node, core_input_shapes
             )

+            # Set to None those core_shapes that depend on dummy_core_inputs,
+            # meaning their value may not be constant across batch dims of the Blockwise
+            if not dummy_core_inputs:
+                # All inputs are unbatched, so the core_shape can be used as is
+                return core_output_shapes
+            else:
+                set_dummy_core_inputs = set(dummy_core_inputs)
+                safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
+                for core_out_shape in safe_core_output_shapes:
+                    for o, core_out_dim in enumerate(core_out_shape):
+                        if set_dummy_core_inputs & set(
+                            explicit_graph_inputs([core_out_dim])
+                        ):
+                            core_out_shape[o] = None
+
+                return safe_core_output_shapes
+
+        safe_core_out_shape = None
+
         out_shapes = []
         for o, (output, sig) in enumerate(
             zip(node.outputs, self.outputs_sig, strict=True)
@@ -321,19 +389,15 @@ def infer_shape(
                 if dim_name in core_dims:
                     core_out_shape.append(core_dims[dim_name])
                 else:
-                    if core_op_infer_shape is not None:
-                        # If the input values are needed to compute the dimension length, we can't use the infer_shape
-                        # of the core_node as the value is not constant across batch dims of the Blockwise
-                        core_out_dim = core_output_shapes[o][i]
-                        if not (
-                            set(dummy_core_inputs)
-                            & set(explicit_graph_inputs([core_out_dim]))
-                        ):
-                            core_out_shape.append(core_out_dim)
-                            continue
-
-                    # Fallback shape requires evaluating the Blockwise Op
-                    core_out_shape.append(Shape_i(batch_ndims + i)(output))
+                    if safe_core_out_shape is None:
+                        # Extract the core shape from the core_op infer_shape on demand
+                        # For many Ops we never need to do this, because all info is in their signature
+                        safe_core_out_shape = extract_core_shape_from_infer_shape()
+                    if (core_out_dim := safe_core_out_shape[o][i]) is not None:
+                        core_out_shape.append(core_out_dim)
+                    else:
+                        # Fallback shape requires evaluating the Blockwise Op
+                        core_out_shape.append(Shape_i(batch_ndims + i)(output))
             out_shapes.append((*batch_shape, *core_out_shape))

         return out_shapes
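Note: the final `(*batch_shape, *core_out_shape)` composition follows NumPy gufunc semantics: batch dims broadcast like ufunc inputs, core dims come from the signature. An illustration with plain NumPy (the shapes here are made up):

    import numpy as np

    # Signature "(m,n),(n,k)->(m,k)": core dims are the trailing ones, batch dims broadcast
    a = np.empty((7, 1, 4, 3))  # batch (7, 1), core (4, 3)
    b = np.empty((5, 3, 2))     # batch (5,),  core (3, 2)
    batch_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])  # (7, 5)
    assert np.matmul(a, b).shape == (*batch_shape, 4, 2)  # (*batch, m, k)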
@@ -448,7 +512,10 @@ def gufunc(
                 )
                 return core_func(*inputs)
         else:
-            core_node = self._create_dummy_core_node(node.inputs)  # type: ignore
+            core_node = self._create_dummy_core_node(
+                cast(list[TensorVariable], node.inputs),
+                propagate_unbatched_core_inputs=True,
+            )
             gufunc = _vectorize_node_perform(
                 core_node,
                 batch_bcast_patterns=batch_bcast_patterns,
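Note: when no specialized gufunc is registered, `_vectorize_node_perform` loops the core node's implementation over the batch dims, similar in spirit to `np.vectorize` with a core signature. A rough stand-in using NumPy only, not the actual implementation:

    import numpy as np

    # Batch a 1-in/1-out core function over leading dims, signature "(m)->()"
    core_norm = np.vectorize(np.linalg.norm, signature="(m)->()")
    batch = np.ones((4, 5, 3))  # batch (4, 5), core (3,)
    assert core_norm(batch).shape == (4, 5)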