Commit 9ad9540

block_diag dot rewrite
1 parent d3bbc20 commit 9ad9540
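
What the commit does, in short: dot products where an operand is built with ``block_diag`` are rewritten to operate block by block, so the large block-diagonal matrix never has to be materialized inside the dot. Below is a minimal usage sketch (editorial, not part of the commit; the variable names are illustrative) of how the effect shows up in a compiled graph:

import pytensor
import pytensor.tensor as pt

# Illustrative graph: a block-diagonal matrix times a vector.
A = pt.tensor("A", shape=(3, 3))
B = pt.tensor("B", shape=(2, 2))
v = pt.tensor("v", shape=(5,))

out = pt.linalg.block_diag(A, B) @ v
fn = pytensor.function([A, B, v], out)

# With this commit, the compiled graph should contain per-block dots
# instead of a BlockDiagonal node feeding one large dot.
pytensor.dprint(fn)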

2 files changed (+142, -4 lines)

2 files changed

+142
-4
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 69 additions & 4 deletions
@@ -29,9 +29,11 @@
     cast,
     constant,
     get_underlying_scalar_constant_value,
+    join,
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros_like,
 )
@@ -99,6 +101,7 @@
 )
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -167,6 +170,72 @@ def local_0_dot_x(fgraph, node):
         return [constant_zero]
 
 
+@register_canonicalize
+@register_specialize
+@register_stabilize
+@node_rewriter([Dot])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
+
+    BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
+    of approximately O(n^3), it is always better to perform two dot products on the smaller matrices rather than
+    a single dot on the larger matrix.
+    """
+    x, y = node.inputs
+    op = node.op
+
+    def check_for_block_diag(x):
+        return x.owner and (
+            isinstance(x.owner.op, BlockDiagonal)
+            or isinstance(x.owner.op, Blockwise)
+            and isinstance(x.owner.op.core_op, BlockDiagonal)
+        )
+
+    if not (check_for_block_diag(x) or check_for_block_diag(y)):
+        return None
+
+    # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
+    # non-block diagonal, and return a new block diagonal
+    if check_for_block_diag(x) and not check_for_block_diag(y):
+        components = x.owner.inputs
+        y_splits = split(
+            y,
+            splits_size=[component.shape[-1] for component in components],
+            n_splits=len(components),
+        )
+        new_components = [
+            op(component, y_split) for component, y_split in zip(components, y_splits)
+        ]
+        new_output = join(0, *new_components)
+    elif not check_for_block_diag(x) and check_for_block_diag(y):
+        components = y.owner.inputs
+        new_components = [op(x, component) for component in components]
+        new_output = join(0, *new_components)
+
+    # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
+    # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
+    elif any(shape is None for shape in (*x.type.shape, *y.type.shape)):
+        return None
+    elif x.ndim == y.ndim and all(
+        x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape)
+    ):
+        x_components = x.owner.inputs
+        y_components = y.owner.inputs
+
+        if len(x_components) != len(y_components):
+            return None
+
+        new_output = BlockDiagonal(len(x_components))(
+            *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)]
+        )
+    else:
+        return None
+
+    copy_stack_trace(node.outputs[0], new_output)
+    return [new_output]
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2496,7 +2565,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )
 
-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
 
 
@@ -3619,7 +3687,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
 
-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3633,7 +3700,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)
 
-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3674,7 +3740,6 @@ def local_useless_conj(fgraph, node):
 
 register_specialize(local_polygamma_to_tri_gamma)
 
-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
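
To make the identities the rewrite relies on concrete, here is a small NumPy/SciPy check (editorial sketch, not part of the commit; the array names are made up). Case 1 corresponds to the split-and-join branch above, Case 2 to the branch that keeps a BlockDiagonal of the per-block products:

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A, B = rng.normal(size=(4, 2)), rng.normal(size=(3, 5))
C = rng.normal(size=(2 + 5, 6))  # rows of C = columns of block_diag(A, B) = 2 + 5

# Case 1: block_diag(A, B) @ C == concat(A @ C[:2], B @ C[2:])
np.testing.assert_allclose(
    block_diag(A, B) @ C,
    np.concatenate([A @ C[:2], B @ C[2:]]),
)

# Case 2: with matching block structure, the product stays block diagonal
X, Y = rng.normal(size=(3, 3)), rng.normal(size=(4, 4))
Z, W = rng.normal(size=(3, 3)), rng.normal(size=(4, 4))
np.testing.assert_allclose(
    block_diag(X, Y) @ block_diag(Z, W),
    block_diag(X @ Z, Y @ W),
)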

tests/tensor/rewriting/test_math.py

Lines changed: 73 additions & 0 deletions
@@ -113,6 +113,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4654,3 +4655,75 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+def test_local_block_diag_dot_to_dot_block_diag():
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4))
+    c = tensor("c", shape=(4, 4))
+    d = tensor("d", shape=(10,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    fn = pytensor.function([a, b, c, d], out)
+    assert not any(
+        isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
+    )
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"),
+    )
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    np.testing.assert_allclose(
+        fn(a_val, b_val, c_val, d_val),
+        fn_expected(a_val, b_val, c_val, d_val),
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
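
A back-of-the-envelope count (editorial, not part of the commit) of why the benchmark above should favour the rewrite: for the matrix @ vector product being benchmarked, a dense dot with the assembled (n, n) block-diagonal matrix costs about 2*n^2 flops, while the per-block dots cost 2 * sum(n_i^2), which is strictly smaller whenever there is more than one block; for matrix-matrix products, where dot is roughly cubic, the gap is even larger.

# Editorial sketch: rough flop counts for the matrix @ vector case benchmarked above.
def dense_flops(block_sizes):
    n = sum(block_sizes)
    return 2 * n * n  # one dot with the assembled (n, n) block-diagonal matrix


def blockwise_flops(block_sizes):
    return sum(2 * n * n for n in block_sizes)  # one small dot per block


sizes = (400, 350, 250)  # e.g. one possible draw for the size=1000 benchmark case
print(dense_flops(sizes))      # 2000000
print(blockwise_flops(sizes))  # 690000 -> roughly 3x fewer operations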
