
Commit 5b1db5a

Allow Blockwise to create dummy core nodes with outer inputs, if these are unbatched
1 parent 2084ded commit 5b1db5a

File tree: 4 files changed, +84 -55 lines changed

pytensor/link/jax/dispatch/blockwise.py
Lines changed: 5 additions & 13 deletions

@@ -1,24 +1,16 @@
 import jax.numpy as jnp

-from pytensor.graph import FunctionGraph
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blockwise import Blockwise


 @jax_funcify.register(Blockwise)
-def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
+def jax_funcify_Blockwise(op: Blockwise, node, **kwargs):
     signature = op.signature
-    core_node = op._create_dummy_core_node(node.inputs)
-    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
-    tuple_core_fn = jax_funcify(core_fgraph)
-
-    if len(node.outputs) == 1:
-
-        def core_fn(*inputs):
-            return tuple_core_fn(*inputs)[0]
-
-    else:
-        core_fn = tuple_core_fn
+    core_node = op._create_dummy_core_node(
+        node.inputs, propagate_unbatched_core_inputs=True
+    )
+    core_fn = jax_funcify(core_node.op, node=core_node, **kwargs)

     vect_fn = jnp.vectorize(core_fn, signature=signature)
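For context, the dispatcher relies on jax.numpy.vectorize with a gufunc signature to map the core function over any leading batch dimensions. A minimal standalone sketch of that mechanism, with a made-up matvec core function and made-up shapes:

import jax.numpy as jnp

def core_matvec(a, b):
    # Core operation defined only on core dims: (m, n) @ (n,) -> (m,)
    return a @ b

vect_fn = jnp.vectorize(core_matvec, signature="(m,n),(n)->(m)")

a = jnp.ones((5, 2, 3))  # batch of 5 matrices
b = jnp.ones((3,))       # unbatched vector, broadcast across the batch
print(vect_fn(a, b).shape)  # (5, 2)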

pytensor/link/numba/dispatch/blockwise.py
Lines changed: 3 additions & 2 deletions

@@ -16,7 +16,7 @@
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


-@numba_funcify.register
+@numba_funcify.register(BlockwiseWithCoreShape)
 def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     [blockwise_node] = op.fgraph.apply_nodes
     blockwise_op: Blockwise = blockwise_node.op
@@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])

     core_node = blockwise_op._create_dummy_core_node(
-        cast(tuple[TensorVariable], blockwise_node.inputs)
+        cast(tuple[TensorVariable], node.inputs[:nin]),
+        propagate_unbatched_core_inputs=True,
     )
     core_op_fn = numba_funcify(
         core_op,
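The registration change in the first hunk is mechanical: numba_funcify behaves like a functools.singledispatch registry, so naming the class in register() removes the dependence on the first parameter's type annotation. A generic illustration with a stand-in registry and a stand-in class (neither is the real PyTensor object):

from functools import singledispatch

@singledispatch
def funcify(op, node=None, **kwargs):
    raise NotImplementedError(type(op))

class FakeBlockwiseWithCoreShape:
    pass

# Explicit form, mirroring the diff: the dispatch class is named directly,
# instead of being inferred from the annotation of `op`.
@funcify.register(FakeBlockwiseWithCoreShape)
def funcify_fake_blockwise(op, node=None, **kwargs):
    return "compiled blockwise"

print(funcify(FakeBlockwiseWithCoreShape()))  # "compiled blockwise"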

pytensor/tensor/blockwise.py
Lines changed: 75 additions & 28 deletions

@@ -32,6 +32,17 @@
 from pytensor.tensor.variable import TensorVariable


+def _squeeze_left(x, stop_at_dim: int | None = None):
+    """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
+    x_dims = x.type.broadcastable
+    squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
+    if stop_at_dim is not None:
+        squeeze_ndim = min(squeeze_ndim, stop_at_dim)
+    if squeeze_ndim == 0:
+        return x
+    return x.squeeze(axis=tuple(range(squeeze_ndim)))
+
+
 def _vectorize_node_perform(
     core_node: Apply,
     batch_bcast_patterns: Sequence[tuple[bool, ...]],
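A quick illustration of what the relocated _squeeze_left helper does; the shapes are made up, and after this commit the helper is importable from pytensor.tensor.blockwise:

import pytensor.tensor as pt
from pytensor.tensor.blockwise import _squeeze_left

x = pt.tensor("x", shape=(1, 1, 5, 3))

# All leading broadcastable (length-1) dims are squeezed away by default.
print(_squeeze_left(x).type.shape)                 # (5, 3)
# stop_at_dim caps how many leading dims may be squeezed.
print(_squeeze_left(x, stop_at_dim=1).type.shape)  # (1, 5, 3)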
@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
 class Blockwise(COp):
     """Generalizes a core `Op` to work with batched dimensions.

-    TODO: Dispatch JAX (should be easy with the vectorize macro)
-    TODO: Dispatch Numba
     TODO: C implementation?
     TODO: Fuse Blockwise?
     """
@@ -202,21 +211,33 @@ def __init__(

         super().__init__(**kwargs)

-    def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
-        core_input_types = []
+    def _create_dummy_core_node(
+        self,
+        inputs: Sequence[TensorVariable],
+        propagate_unbatched_core_inputs: bool = False,
+        return_dummy_inputs: bool = False,
+    ) -> Apply:
+        core_inputs = []
+        core_dummy_inputs = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
                 )
             # ndim_supp = 0 case
-            if not sig:
-                core_shape = ()
+            inp_ndim = inp.type.ndim
+            batch_ndim = inp_ndim - len(sig)
+            core_shape = inp.type.shape[batch_ndim:]
+            if propagate_unbatched_core_inputs and all(
+                inp.type.broadcastable[:batch_ndim]
+            ):
+                core_inputs.append(_squeeze_left(inp, batch_ndim))
             else:
-                core_shape = inp.type.shape[-len(sig) :]
-            core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape))
+                dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape)
+                core_inputs.append(dummy_inp)
+                core_dummy_inputs.append(dummy_inp)

-        core_node = self.core_op.make_node(*core_input_types)
+        core_node = self.core_op.make_node(*core_inputs)

         if len(core_node.outputs) != len(self.outputs_sig):
             raise ValueError(
@@ -230,6 +251,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
                     f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
                 )

+        if return_dummy_inputs:
+            return core_node, core_dummy_inputs
+
         return core_node

     def make_node(self, *inputs):
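A hedged sketch of what the new propagate_unbatched_core_inputs flag changes, poking at the private helper directly; the Dot core op, the signature, and the shapes are chosen purely for illustration:

import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot

op = Blockwise(Dot(), signature="(m,k),(k,n)->(m,n)")
x = pt.tensor("x", shape=(7, 2, 3))  # batched: one leading batch dim
y = pt.tensor("y", shape=(3, 4))     # unbatched: core dims only

# Default behaviour: every input is replaced by a fresh dummy with core dims only.
default_node = op._create_dummy_core_node([x, y])

# With the flag, `y` has no batch dims, so it is reused (after _squeeze_left)
# instead of being swapped for a placeholder; only `x` gets a dummy input.
propagated_node = op._create_dummy_core_node(
    [x, y], propagate_unbatched_core_inputs=True
)
print(propagated_node.inputs[1] is y)  # True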
@@ -298,11 +322,17 @@ def infer_shape(

         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)

-        # Try to extract the core shapes from the core_op
-        core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
-        if core_op_infer_shape is not None:
-            dummy_core_node = self._create_dummy_core_node(node.inputs)
-            dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
+        def extract_core_shape_from_infer_shape():
+            # Try to extract the core shapes from the core_op
+            core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
+            if core_op_infer_shape is None:
+                return [[None] * out.ndim for out in node.outputs]
+
+            dummy_core_node, dummy_core_inputs = self._create_dummy_core_node(
+                node.inputs,
+                return_dummy_inputs=True,
+                propagate_unbatched_core_inputs=True,
+            )
             dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
             core_input_shapes = [
                 input_shape[batch_ndims:] for input_shape in input_shapes
@@ -311,6 +341,25 @@
                 dummy_fgraph, dummy_core_node, core_input_shapes
             )

+            # Set to None those core_shapes that depend on dummy_core_inputs,
+            # meaning their value may not be constant across batch dims of the Blockwise
+            if not dummy_core_inputs:
+                # All inputs are unbatched, so the core_shape can be used as is
+                return core_output_shapes
+            else:
+                set_dummy_core_inputs = set(dummy_core_inputs)
+                safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
+                for core_out_shape in safe_core_output_shapes:
+                    for o, core_out_dim in enumerate(core_out_shape):
+                        if set_dummy_core_inputs & set(
+                            explicit_graph_inputs([core_out_dim])
+                        ):
+                            core_out_shape[o] = None
+
+                return safe_core_output_shapes
+
+        safe_core_out_shape = None
+
         out_shapes = []
         for o, (output, sig) in enumerate(
             zip(node.outputs, self.outputs_sig, strict=True)
@@ -321,19 +370,15 @@
                 if dim_name in core_dims:
                     core_out_shape.append(core_dims[dim_name])
                 else:
-                    if core_op_infer_shape is not None:
-                        # If the input values are needed to compute the dimension length, we can't use the infer_shape
-                        # of the core_node as the value is not constant across batch dims of the Blockwise
-                        core_out_dim = core_output_shapes[o][i]
-                        if not (
-                            set(dummy_core_inputs)
-                            & set(explicit_graph_inputs([core_out_dim]))
-                        ):
-                            core_out_shape.append(core_out_dim)
-                            continue
-
-                    # Fallback shape requires evaluating the Blockwise Op
-                    core_out_shape.append(Shape_i(batch_ndims + i)(output))
+                    if safe_core_out_shape is None:
+                        # Extract the core shape from the core_op infer_shape on demand
+                        # For many Ops we never need to do this, because all info is in their signature
+                        safe_core_out_shape = extract_core_shape_from_infer_shape()
+                    if (core_out_dim := safe_core_out_shape[o][i]) is not None:
+                        core_out_shape.append(core_out_dim)
+                    else:
+                        # Fallback shape requires evaluating the Blockwise Op
+                        core_out_shape.append(Shape_i(batch_ndims + i)(output))
             out_shapes.append((*batch_shape, *core_out_shape))

         return out_shapes
@@ -448,7 +493,9 @@ def gufunc(
                 )
                 return core_func(*inputs)
         else:
-            core_node = self._create_dummy_core_node(node.inputs)  # type: ignore
+            core_node = self._create_dummy_core_node(
+                node.inputs, propagate_unbatched_core_inputs=True
+            )  # type: ignore
             gufunc = _vectorize_node_perform(
                 core_node,
                 batch_bcast_patterns=batch_bcast_patterns,
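Finally, an end-to-end sanity check of the perform path touched by the last hunk, again with made-up shapes and Dot standing in for an arbitrary core op; the unbatched left operand is broadcast across the batch:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot

x = pt.tensor("x", shape=(2, 3))     # unbatched left operand
y = pt.tensor("y", shape=(7, 3, 4))  # right operand with a batch dim of 7

out = Blockwise(Dot(), signature="(m,k),(k,n)->(m,n)")(x, y)
fn = pytensor.function([x, y], out)

res = fn(np.ones((2, 3), dtype=x.dtype), np.ones((7, 3, 4), dtype=y.dtype))
print(res.shape)  # (7, 2, 4)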

pytensor/tensor/rewriting/blockwise.py
Lines changed: 1 addition & 12 deletions

@@ -4,7 +4,7 @@
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.math import Dot
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
@@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
         return local_useless_unbatched_blockwise.fn(fgraph, node)


-def _squeeze_left(x, stop_at_dim: int | None = None):
-    """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
-    x_dims = x.type.broadcastable
-    squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
-    if stop_at_dim is not None:
-        squeeze_ndim = min(squeeze_ndim, stop_at_dim)
-    if squeeze_ndim == 0:
-        return x
-    return x.squeeze(axis=tuple(range(squeeze_ndim)))
-
-
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
 def local_blockwise_alloc(fgraph, node):
