Commit 9455b86

pair coding results

1 parent 7cef064

File tree

3 files changed: +52 -55 lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,7 @@
     broadcasted_by,
     register_canonicalize,
     register_specialize,
+    register_stabilize,
 )
 from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.variable import TensorConstant

@@ -395,6 +396,7 @@ def is_dimshuffle_useless(new_order, input):


 @register_canonicalize
+@register_stabilize
 @register_specialize
 @node_rewriter([DimShuffle])
 def local_dimshuffle_lift(fgraph, node):
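
This registers the existing local_dimshuffle_lift rewrite with the stabilize database in addition to canonicalize and specialize, so the DimShuffle lift also runs during the stabilization pass. A minimal sketch of the registration pattern, assuming the usual PyTensor module layout for these helpers; the rewrite body below is a toy placeholder, not the real local_dimshuffle_lift:

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import (
    register_canonicalize,
    register_specialize,
    register_stabilize,
)


# Stacking the register_* decorators enrolls one node rewriter in several
# rewrite databases; after this commit, the DimShuffle lift is tried during
# canonicalize, stabilize, and specialize alike.
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([DimShuffle])
def toy_dimshuffle_rewrite(fgraph, node):
    # Hypothetical placeholder body: returning None leaves the node
    # unchanged; a real rewrite returns the replacement outputs.
    return None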

pytensor/tensor/rewriting/math.py

Lines changed: 25 additions & 44 deletions

@@ -183,59 +183,40 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
     if not isinstance(node.op.core_op, BlockDiagonal):
         return

-    def check_for_block_diag(x):
-        return x.owner and (
-            isinstance(x.owner.op, BlockDiagonal)
-            or isinstance(x.owner.op, Blockwise)
-            and isinstance(x.owner.op.core_op, BlockDiagonal)
-        )
-
     # Check that the BlockDiagonal is an input to a Dot node:
     for client in get_clients_at_depth(fgraph, node, depth=1):
-        if not isinstance(client.op, Dot):
+        if not (
+            (
+                isinstance(client.op, Dot)
+                and all(input.ndim == 2 for input in client.inputs)
+            )
+            or client.op == _matrix_matrix_matmul
+        ):
             continue

         op = client.op
-        x, y = client.inputs

-        if not (check_for_block_diag(x) or check_for_block_diag(y)):
-            continue
+        client_idx = client.inputs.index(node.outputs[0])

-        # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
-        # non-block diagonal, and return a new block diagonal
-        if check_for_block_diag(x) and not check_for_block_diag(y):
-            components = x.owner.inputs
-            y_splits = split(
-                y,
-                splits_size=[component.shape[-1] for component in components],
-                n_splits=len(components),
-            )
-            new_components = [
-                op(component, y_split)
-                for component, y_split in zip(components, y_splits)
-            ]
-            new_output = join(0, *new_components)
-
-        elif not check_for_block_diag(x) and check_for_block_diag(y):
-            components = y.owner.inputs
-            x_splits = split(
-                x,
-                splits_size=[component.shape[0] for component in components],
-                n_splits=len(components),
-                axis=1,
-            )
+        other_input = client.inputs[1 - client_idx]
+        components = node.inputs

-            new_components = [
-                op(x_split, component)
-                for component, x_split in zip(components, x_splits)
-            ]
-            new_output = join(1, *new_components)
+        split_axis = -2 if client_idx == 0 else -1
+        shape_idx = -1 if client_idx == 0 else -2

-        # Case 2: Both inputs are BlockDiagonal. Do nothing
-        else:
-            # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
-            # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
-            continue
+        other_dot_input_split = split(
+            other_input,
+            splits_size=[component.shape[shape_idx] for component in components],
+            n_splits=len(components),
+            axis=split_axis,
+        )
+        new_components = [
+            op(component, other_split)
+            if client_idx == 0
+            else op(other_split, component)
+            for component, other_split in zip(components, other_dot_input_split)
+        ]
+        new_output = join(split_axis, *new_components)

         copy_stack_trace(node.outputs[0], new_output)
         return {client.outputs[0]: new_output}
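
The rewritten pass collapses the old left/right case analysis into one split/join: it splits the non-block-diagonal operand along the contraction edge by the sizes of the block-diagonal components, applies the dot per component, and joins the partial results. The underlying identity is that a block-diagonal matrix never mixes blocks under matrix multiplication. A NumPy sketch of that identity (illustration only; the actual rewrite uses PyTensor's split and join ops):

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))  # first block
B = rng.normal(size=(3, 5))  # second block
Z = rng.normal(size=(7, 6))  # 7 rows == 2 + 5, the blocks' column counts

# Direct product with the materialized block-diagonal matrix ...
direct = block_diag(A, B) @ Z

# ... equals the per-block products on row-splits of Z, concatenated back.
Z1, Z2 = np.split(Z, [A.shape[-1]], axis=-2)
rewritten = np.concatenate([A @ Z1, B @ Z2], axis=-2)

np.testing.assert_allclose(direct, rewritten)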

tests/tensor/rewriting/test_math.py

Lines changed: 25 additions & 11 deletions

@@ -4658,15 +4658,21 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):


 @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
-def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
+@pytest.mark.parametrize(
+    "batch_left", [True, False], ids=["batched_left", "unbatched_left"]
+)
+@pytest.mark.parametrize(
+    "batch_right", [True, False], ids=["batched_right", "unbatched_right"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right):
     """
     Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
     """
     a = tensor("a", shape=(4, 2))
-    b = tensor("b", shape=(2, 4))
+    b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4))
     c = tensor("c", shape=(4, 4))
     d = tensor("d", shape=(10, 10))
-    e = tensor("e", shape=(10, 10))
+    e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10))

     x = pt.linalg.block_diag(a, b, c)

@@ -4676,30 +4682,38 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     else:
         out = [d @ x, e @ x]

-    fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
+    with config.change_flags(optimizer_verbose=True):
+        fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
+
     assert not any(
         isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
     )

     fn_expected = pytensor.function(
         [a, b, c, d, e],
         out,
-        mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
+        mode=Mode(linker="py", optimizer=None),
     )

+    # TODO: Count Dots
+
     rng = np.random.default_rng()
     a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
     b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
     c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
     d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
     e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)

-    np.testing.assert_allclose(
-        fn(a_val, b_val, c_val, d_val, e_val),
-        fn_expected(a_val, b_val, c_val, d_val, e_val),
-        atol=1e-6 if config.floatX == "float32" else 1e-12,
-        rtol=1e-6 if config.floatX == "float32" else 1e-12,
-    )
+    rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val)
+    expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val)
+
+    for out, expected in zip(rewrite_outs, expected_outs):
+        np.testing.assert_allclose(
+            out,
+            expected,
+            atol=1e-6 if config.floatX == "float32" else 1e-12,
+            rtol=1e-6 if config.floatX == "float32" else 1e-12,
+        )


 @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
