
Commit 10c36d2

Apply useless blockwise rewrite when there are only dummy batch dims
Also extend the eager rewrite to more Ops.

The Blockwise MatrixInverse grad test became more sensitive in float32, because desired stabilization rewrites (mainly `inv_as_solve`) that target the Dot of a Blockwise{MatrixInverse} are now triggered in the default Blockwise grad but not in the non-default, non-Blockwise grad.
1 parent fe5865e commit 10c36d2
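For context, a hedged sketch of the case this commit targets (the variable names and shapes below are illustrative assumptions, not taken from the diff): a Blockwise node whose batch dimensions are all broadcastable ("dummy") can be replaced by its core Op, with the dummy dimensions squeezed away first and padded back on afterwards.

# Hedged sketch only; Blockwise, MatrixInverse and shape_padleft are the objects
# referenced in the diff, but the concrete graph below is an assumption for illustration.
import pytensor.tensor as pt
from pytensor.tensor.basic import shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse

x_core = pt.matrix("x")  # core (m, m) input
x = x_core[None]         # add one dummy (broadcastable, length-1) batch dimension

blockwise_inv = Blockwise(MatrixInverse(), signature="(m,m)->(m,m)")
y = blockwise_inv(x)     # Blockwise whose only batch dim is a dummy one

# After local_useless_unbatched_blockwise, the graph is equivalent to squeezing the
# dummy batch dim, applying the core Op directly, and padding the dim back on:
y_equivalent = shape_padleft(MatrixInverse()(x.squeeze(0)), 1)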

3 files changed (+43 -12 lines)

pytensor/tensor/blockwise.py

Lines changed: 5 additions & 5 deletions
@@ -163,16 +163,16 @@ def make_node(self, *inputs):
 
         return Apply(self, batched_inputs, batched_outputs)
 
-    def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int:
-        return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0]))
+    def batch_ndim(self, node: Apply) -> int:
+        return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0]))
 
     def infer_shape(
         self, fgraph, node, input_shapes
     ) -> list[tuple[TensorVariable, ...]]:
         from pytensor.tensor import broadcast_shape
         from pytensor.tensor.shape import Shape_i
 
-        batch_ndims = self._batch_ndim_from_outputs(node.outputs)
+        batch_ndims = self.batch_ndim(node)
         core_dims: dict[str, Any] = {}
         batch_shapes = []
         for input_shape, sig in zip(input_shapes, self.inputs_sig):
@@ -278,7 +278,7 @@ def L_op(self, inputs, outs, ograds):
             return new_rval
 
         # Sum out the broadcasted dimensions
-        batch_ndims = self._batch_ndim_from_outputs(outs)
+        batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
@@ -320,7 +320,7 @@ def core_func(*inner_inputs):
         return self._gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
-        batch_ndim = self._batch_ndim_from_outputs(node.outputs)
+        batch_ndim = self.batch_ndim(node)
 
         for dims_and_bcast in zip(
             *[
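The renamed batch_ndim helper derives the number of batch dimensions from the Apply node rather than from a separate list of outputs: output ndim minus the length of the core output signature. A small hedged illustration (the core Op and input rank are assumptions chosen for the example):

# Illustrative use of the new Blockwise.batch_ndim helper; Op and rank are assumptions.
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse

blockwise_inv = Blockwise(MatrixInverse(), signature="(m,m)->(m,m)")
node = blockwise_inv(pt.tensor4("w")).owner  # 4d input, 2d core signature

# Output ndim (4) minus the length of the core output signature "(m,m)" (2) is 2:
assert blockwise_inv.batch_ndim(node) == 2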

pytensor/tensor/rewriting/blockwise.py

Lines changed: 37 additions & 6 deletions
@@ -2,9 +2,15 @@
 from pytensor.graph import node_rewriter
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.tensor.basic import Alloc, ARange, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.math import _matrix_matrix_matmul
-from pytensor.tensor.rewriting.basic import register_canonicalize
+from pytensor.tensor.math import Dot
+from pytensor.tensor.rewriting.basic import (
+    register_canonicalize,
+    register_specialize,
+    register_stabilize,
+)
+from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
 
 
 @node_rewriter([Blockwise])
@@ -29,8 +35,17 @@ def local_useless_unbatched_blockwise(fgraph, node):
     op = node.op
     inputs = node.inputs
 
-    if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0:
-        return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)
+    batch_ndims = node.op.batch_ndim(node)
+    if all(all(inp.type.broadcastable[:batch_ndims]) for inp in inputs):
+        if batch_ndims:
+            # Remove dummy batch dims
+            axis = tuple(range(batch_ndims))
+            inputs = [inp.squeeze(axis) for inp in inputs]
+        new_outs = op.core_op.make_node(*inputs).outputs
+        if batch_ndims:
+            # Reintroduce dummy batch dims
+            new_outs = [shape_padleft(out, batch_ndims) for out in new_outs]
+        return copy_stack_trace(node.outputs, new_outs)
 
 
 # We register this rewrite late, so that other rewrites need only target Blockwise Ops
@@ -46,6 +61,22 @@ def local_useless_unbatched_blockwise(fgraph, node):
 
 # Avoid redundant cases early on for Ops whose default form is not Blockwised
 @register_canonicalize
-@node_rewriter(tracks=[_matrix_matrix_matmul])
+@register_stabilize
+@register_specialize
+@node_rewriter(tracks=[Blockwise])
 def local_eager_useless_unbatched_blockwise(fgraph, node):
-    return local_useless_unbatched_blockwise.fn(fgraph, node)
+    if isinstance(
+        node.op.core_op,
+        (
+            # Many Dot-related rewrites (e.g., all of BlasOpt) happen before specialize
+            Dot,
+            # These Ops can't always be trivially vectorized at runtime,
+            # Since their inputs may imply non-rectangular shapes.
+            Alloc,
+            ARange,
+            Subtensor,
+            AdvancedSubtensor,
+            AdvancedIncSubtensor,
+        ),
+    ):
+        return local_useless_unbatched_blockwise.fn(fgraph, node)
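A hedged sketch of why the eager variant is now registered in canonicalize, stabilize, and specialize for these core Ops (the graph below is illustrative, not from the commit): rewrites such as the BLAS optimizations only recognize the plain Dot Op, so an unbatched Blockwise{Dot} needs to be unwrapped before those passes run.

# Illustrative only: matmul of two unbatched matrices is built on a Blockwise-capable
# Op, and the eager rewrite lets later Dot/BLAS rewrites see the core Dot underneath.
import pytensor
import pytensor.tensor as pt

a = pt.matrix("a")
b = pt.matrix("b")
out = pt.matmul(a, b)  # no real batch dimensions here

f = pytensor.function([a, b], out)
pytensor.dprint(f)  # the compiled graph contains the core dot / BLAS Ops, not a Blockwise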

tests/tensor/test_blockwise.py

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ def test_grad(self):
             pt_out,
             np_out,
             rtol=1e-7 if config.floatX == "float64" else 1e-5,
-            atol=1e-6 if config.floatX == "float64" else 1e-5,
+            atol=1e-6 if config.floatX == "float64" else 1e-4,
         )
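For context on the looser float32 atol, a hedged sketch of the pattern the commit message refers to (the graph itself is an assumption for illustration): the Blockwise gradient of MatrixInverse produces a dot with the inverse, which the `inv_as_solve` stabilization rewrite turns into a solve, so the float32 result can differ slightly from a reference computed without that rewrite.

# Illustrative pattern targeted by inv_as_solve; the variables are assumptions.
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import matrix_inverse

A = pt.matrix("A")
b = pt.matrix("b")
graph = pt.dot(matrix_inverse(A), b)
# During stabilization, inv_as_solve rewrites dot(inv(A), b) into solve(A, b), which
# avoids explicitly forming the inverse and changes float32 round-off slightly.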
