Skip to content

Commit a62e785

Browse files
committed
Allow Blockwise to create dummy core nodes with outer inputs, if these are unbatched
1 parent efc9d69 commit a62e785

File tree

4 files changed

+105
-56
lines changed

4 files changed

+105
-56
lines changed

pytensor/link/jax/dispatch/blockwise.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,16 @@
11
import jax.numpy as jnp
22

3-
from pytensor.graph import FunctionGraph
43
from pytensor.link.jax.dispatch import jax_funcify
54
from pytensor.tensor.blockwise import Blockwise
65

76

87
@jax_funcify.register(Blockwise)
9-
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
8+
def jax_funcify_Blockwise(op: Blockwise, node, **kwargs):
109
signature = op.signature
11-
core_node = op._create_dummy_core_node(node.inputs)
12-
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
13-
tuple_core_fn = jax_funcify(core_fgraph)
14-
15-
if len(node.outputs) == 1:
16-
17-
def core_fn(*inputs):
18-
return tuple_core_fn(*inputs)[0]
19-
20-
else:
21-
core_fn = tuple_core_fn
10+
core_node = op._create_dummy_core_node(
11+
node.inputs, propagate_unbatched_core_inputs=True
12+
)
13+
core_fn = jax_funcify(core_node.op, node=core_node, **kwargs)
2214

2315
vect_fn = jnp.vectorize(core_fn, signature=signature)
2416

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
1717

1818

19-
@numba_funcify.register
19+
@numba_funcify.register(BlockwiseWithCoreShape)
2020
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
2121
[blockwise_node] = op.fgraph.apply_nodes
2222
blockwise_op: Blockwise = blockwise_node.op
@@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
2626
core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
2727

2828
core_node = blockwise_op._create_dummy_core_node(
29-
cast(tuple[TensorVariable], blockwise_node.inputs)
29+
cast(tuple[TensorVariable], node.inputs[:nin]),
30+
propagate_unbatched_core_inputs=True,
3031
)
3132
core_op_fn = numba_funcify(
3233
core_op,

pytensor/tensor/blockwise.py

Lines changed: 96 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Sequence
2-
from typing import Any, cast
2+
from typing import Any, Literal, cast, overload
33

44
import numpy as np
55
from numpy import broadcast_shapes, empty
@@ -32,6 +32,17 @@
3232
from pytensor.tensor.variable import TensorVariable
3333

3434

35+
def _squeeze_left(x, stop_at_dim: int | None = None):
36+
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
37+
x_dims = x.type.broadcastable
38+
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
39+
if stop_at_dim is not None:
40+
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
41+
if squeeze_ndim == 0:
42+
return x
43+
return x.squeeze(axis=tuple(range(squeeze_ndim)))
44+
45+
3546
def _vectorize_node_perform(
3647
core_node: Apply,
3748
batch_bcast_patterns: Sequence[tuple[bool, ...]],
@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
143154
class Blockwise(COp):
144155
"""Generalizes a core `Op` to work with batched dimensions.
145156
146-
TODO: Dispatch JAX (should be easy with the vectorize macro)
147-
TODO: Dispatch Numba
148157
TODO: C implementation?
149158
TODO: Fuse Blockwise?
150159
"""
@@ -202,21 +211,52 @@ def __init__(
202211

203212
super().__init__(**kwargs)
204213

205-
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
206-
core_input_types = []
214+
@overload
215+
def _create_dummy_core_node(
216+
self,
217+
inputs: Sequence[TensorVariable],
218+
*,
219+
propagate_unbatched_core_inputs: bool = False,
220+
return_dummy_inputs: Literal[False] = ...,
221+
) -> Apply: ...
222+
223+
@overload
224+
def _create_dummy_core_node(
225+
self,
226+
inputs: Sequence[TensorVariable],
227+
*,
228+
propagate_unbatched_core_inputs: bool = False,
229+
return_dummy_inputs: Literal[True] = ...,
230+
) -> tuple[Apply, list[TensorVariable]]: ...
231+
232+
def _create_dummy_core_node(
233+
self,
234+
inputs: Sequence[TensorVariable],
235+
*,
236+
propagate_unbatched_core_inputs: bool = False,
237+
return_dummy_inputs: bool = False,
238+
) -> Apply | tuple[Apply, list[TensorVariable]]:
239+
core_inputs = []
240+
core_dummy_inputs = []
207241
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
208242
if inp.type.ndim < len(sig):
209243
raise ValueError(
210244
f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
211245
)
212246
# ndim_supp = 0 case
213-
if not sig:
214-
core_shape = ()
247+
inp_ndim = inp.type.ndim
248+
batch_ndim = inp_ndim - len(sig)
249+
core_shape = inp.type.shape[batch_ndim:]
250+
if propagate_unbatched_core_inputs and all(
251+
inp.type.broadcastable[:batch_ndim]
252+
):
253+
core_inputs.append(_squeeze_left(inp, batch_ndim))
215254
else:
216-
core_shape = inp.type.shape[-len(sig) :]
217-
core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape))
255+
dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape)
256+
core_inputs.append(dummy_inp)
257+
core_dummy_inputs.append(dummy_inp)
218258

219-
core_node = self.core_op.make_node(*core_input_types)
259+
core_node = self.core_op.make_node(*core_inputs)
220260

221261
if len(core_node.outputs) != len(self.outputs_sig):
222262
raise ValueError(
@@ -230,6 +270,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
230270
f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
231271
)
232272

273+
if return_dummy_inputs:
274+
return core_node, core_dummy_inputs
275+
233276
return core_node
234277

235278
def make_node(self, *inputs):
@@ -298,11 +341,17 @@ def infer_shape(
298341

299342
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
300343

301-
# Try to extract the core shapes from the core_op
302-
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
303-
if core_op_infer_shape is not None:
304-
dummy_core_node = self._create_dummy_core_node(node.inputs)
305-
dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
344+
def extract_core_shape_from_infer_shape():
345+
# Try to extract the core shapes from the core_op
346+
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
347+
if core_op_infer_shape is None:
348+
return [[None] * out.ndim for out in node.outputs]
349+
350+
dummy_core_node, dummy_core_inputs = self._create_dummy_core_node(
351+
node.inputs,
352+
return_dummy_inputs=True,
353+
propagate_unbatched_core_inputs=True,
354+
)
306355
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
307356
core_input_shapes = [
308357
input_shape[batch_ndims:] for input_shape in input_shapes
@@ -311,6 +360,25 @@ def infer_shape(
311360
dummy_fgraph, dummy_core_node, core_input_shapes
312361
)
313362

363+
# Set to None those core_shapes that depend on dummy_core_inputs,
364+
# meaning their value may not be constant across batch dims of the Blockwise
365+
if not dummy_core_inputs:
366+
# All inputs are unbatched, so the core_shape can be used as is
367+
return core_output_shapes
368+
else:
369+
set_dummy_core_inputs = set(dummy_core_inputs)
370+
safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
371+
for core_out_shape in safe_core_output_shapes:
372+
for o, core_out_dim in enumerate(core_out_shape):
373+
if set_dummy_core_inputs & set(
374+
explicit_graph_inputs([core_out_dim])
375+
):
376+
core_out_shape[o] = None
377+
378+
return safe_core_output_shapes
379+
380+
safe_core_out_shape = None
381+
314382
out_shapes = []
315383
for o, (output, sig) in enumerate(
316384
zip(node.outputs, self.outputs_sig, strict=True)
@@ -321,19 +389,15 @@ def infer_shape(
321389
if dim_name in core_dims:
322390
core_out_shape.append(core_dims[dim_name])
323391
else:
324-
if core_op_infer_shape is not None:
325-
# If the input values are needed to compute the dimension length, we can't use the infer_shape
326-
# of the core_node as the value is not constant across batch dims of the Blockwise
327-
core_out_dim = core_output_shapes[o][i]
328-
if not (
329-
set(dummy_core_inputs)
330-
& set(explicit_graph_inputs([core_out_dim]))
331-
):
332-
core_out_shape.append(core_out_dim)
333-
continue
334-
335-
# Fallback shape requires evaluating the Blockwise Op
336-
core_out_shape.append(Shape_i(batch_ndims + i)(output))
392+
if safe_core_out_shape is None:
393+
# Extract the core shape from the core_op infer_shape on demand
394+
# For many Ops we never need to do this, because all info is in their signature
395+
safe_core_out_shape = extract_core_shape_from_infer_shape()
396+
if (core_out_dim := safe_core_out_shape[o][i]) is not None:
397+
core_out_shape.append(core_out_dim)
398+
else:
399+
# Fallback shape requires evaluating the Blockwise Op
400+
core_out_shape.append(Shape_i(batch_ndims + i)(output))
337401
out_shapes.append((*batch_shape, *core_out_shape))
338402

339403
return out_shapes
@@ -448,7 +512,10 @@ def gufunc(
448512
)
449513
return core_func(*inputs)
450514
else:
451-
core_node = self._create_dummy_core_node(node.inputs) # type: ignore
515+
core_node = self._create_dummy_core_node(
516+
cast(list[TensorVariable], node.inputs),
517+
propagate_unbatched_core_inputs=True,
518+
)
452519
gufunc = _vectorize_node_perform(
453520
core_node,
454521
batch_bcast_patterns=batch_bcast_patterns,

pytensor/tensor/rewriting/blockwise.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pytensor.graph.replace import vectorize_node
55
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
66
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
7-
from pytensor.tensor.blockwise import Blockwise
7+
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
88
from pytensor.tensor.math import Dot
99
from pytensor.tensor.rewriting.basic import (
1010
register_canonicalize,
@@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
9090
return local_useless_unbatched_blockwise.fn(fgraph, node)
9191

9292

93-
def _squeeze_left(x, stop_at_dim: int | None = None):
94-
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
95-
x_dims = x.type.broadcastable
96-
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
97-
if stop_at_dim is not None:
98-
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
99-
if squeeze_ndim == 0:
100-
return x
101-
return x.squeeze(axis=tuple(range(squeeze_ndim)))
102-
103-
10493
@register_specialize("shape_unsafe")
10594
@node_rewriter([Blockwise])
10695
def local_blockwise_alloc(fgraph, node):

0 commit comments

Comments
 (0)