
Commit d9d7566

block_diag dot rewrite
1 parent: d4e8f73

2 files changed: +142, -4 lines

pytensor/tensor/rewriting/math.py

Lines changed: 69 additions & 4 deletions
@@ -29,9 +29,11 @@
     constant,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros,
     zeros_like,
@@ -96,6 +98,7 @@
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -146,6 +149,72 @@ def local_0_dot_x(fgraph, node):
         return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]


+@register_canonicalize
+@register_specialize
+@register_stabilize
+@node_rewriter([Dot])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
+
+    ``block_diag(A, B)`` assembles a matrix of shape ``(n1 + n2, m1 + m2)`` from blocks of shape ``(n1, m1)``
+    and ``(n2, m2)``. Because dot has complexity of approximately O(n^3), it is always cheaper to perform
+    two dot products on the smaller blocks than a single dot on the larger assembled matrix.
+    """
+    x, y = node.inputs
+    op = node.op
+
+    def check_for_block_diag(x):
+        return x.owner and (
+            isinstance(x.owner.op, BlockDiagonal)
+            or isinstance(x.owner.op, Blockwise)
+            and isinstance(x.owner.op.core_op, BlockDiagonal)
+        )
+
+    if not (check_for_block_diag(x) or check_for_block_diag(y)):
+        return None
+
+    # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal
+    # with the non-block-diagonal input, and return a new block diagonal
+    if check_for_block_diag(x) and not check_for_block_diag(y):
+        components = x.owner.inputs
+        y_splits = split(
+            y,
+            splits_size=[component.shape[-1] for component in components],
+            n_splits=len(components),
+        )
+        new_components = [
+            op(component, y_split) for component, y_split in zip(components, y_splits)
+        ]
+        new_output = join(0, *new_components)
+    elif not check_for_block_diag(x) and check_for_block_diag(y):
+        components = y.owner.inputs
+        new_components = [op(x, component) for component in components]
+        new_output = join(0, *new_components)
+
+    # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and
+    # identical. In that case, blockdiag(a, b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
+    elif any(shape is None for shape in (*x.type.shape, *y.type.shape)):
+        return None
+    elif x.ndim == y.ndim and all(
+        x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape)
+    ):
+        x_components = x.owner.inputs
+        y_components = y.owner.inputs
+
+        if len(x_components) != len(y_components):
+            return None
+
+        new_output = BlockDiagonal(len(x_components))(
+            *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)]
+        )
+    else:
+        return None
+
+    copy_stack_trace(node.outputs[0], new_output)
+    return [new_output]
+
+
 @register_canonicalize
 @node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2582,7 +2651,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )

-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")


@@ -3720,7 +3788,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")

-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3734,7 +3801,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)

-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3775,7 +3841,6 @@ def local_useless_conj(fgraph, node):

 register_specialize(local_polygamma_to_tri_gamma)

-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
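
The two algebraic facts the rewrite relies on can be checked directly in NumPy. The following standalone sketch is not part of the commit; it uses scipy.linalg.block_diag in place of pt.linalg.block_diag, with illustrative block sizes, to confirm the Case 1 and Case 2 identities.

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))  # blocks of shape (n1, m1) and (n2, m2)
B = rng.normal(size=(2, 2))
C = rng.normal(size=(5, 4))  # 5 == m1 + m2, so block_diag(A, B) @ C is defined

# Case 1: only the left operand is block-diagonal.
# block_diag(A, B) @ C == concat(A @ C[:m1], B @ C[m1:]) along axis 0,
# which is exactly the split/join graph the rewrite builds.
lhs = block_diag(A, B) @ C
rhs = np.concatenate([A @ C[:3], B @ C[3:]], axis=0)
assert np.allclose(lhs, rhs)

# Case 2: both operands are block-diagonal with matching shapes and block counts.
# block_diag(A, B) @ block_diag(C1, C2) == block_diag(A @ C1, B @ C2).
C1 = rng.normal(size=(3, 3))
C2 = rng.normal(size=(2, 2))
assert np.allclose(
    block_diag(A, B) @ block_diag(C1, C2),
    block_diag(A @ C1, B @ C2),
)

The payoff is that one dot on the assembled (n1 + n2)-sized matrix is replaced by two dots on the individual blocks plus a cheap split and join, which is where the cubic cost argument in the docstring comes from.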

tests/tensor/rewriting/test_math.py

Lines changed: 73 additions & 0 deletions
@@ -115,6 +115,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,75 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+def test_local_block_diag_dot_to_dot_block_diag():
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4))
+    c = tensor("c", shape=(4, 4))
+    d = tensor("d", shape=(10,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    fn = pytensor.function([a, b, c, d], out)
+    assert not any(
+        isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
+    )
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"),
+    )
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    np.testing.assert_allclose(
+        fn(a_val, b_val, c_val, d_val),
+        fn_expected(a_val, b_val, c_val, d_val),
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
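
As a usage illustration (not part of the commit), the graph that the Case 1 branch produces can also be built by hand with the public pt.split and pt.join helpers and compared numerically against the block-diagonal dot. The variable names and block sizes below are made up for the example.

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.tensor("a", shape=(3, 3))
b = pt.tensor("b", shape=(2, 2))
c = pt.tensor("c", shape=(5, 4))

# Original graph: one dot on the assembled (5, 5) block-diagonal matrix.
out = pt.linalg.block_diag(a, b) @ c

# Hand-built equivalent of what the rewrite produces: split c by the blocks'
# column counts, multiply each block by its slice, and join along axis 0.
c_top, c_bottom = pt.split(c, splits_size=[3, 2], n_splits=2)
out_manual = pt.join(0, a @ c_top, b @ c_bottom)

fn = pytensor.function([a, b, c], [out, out_manual])
rng = np.random.default_rng(0)
vals = [
    rng.normal(size=(3, 3)).astype(a.type.dtype),
    rng.normal(size=(2, 2)).astype(b.type.dtype),
    rng.normal(size=(5, 4)).astype(c.type.dtype),
]
res, res_manual = fn(*vals)
np.testing.assert_allclose(res, res_manual, rtol=1e-5, atol=1e-6)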
