From a3e9f8995acd2c6aeaa7932c4c725de1c6249840 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Tue, 8 Jul 2025 22:36:49 +0800
Subject: [PATCH] Preserve static shape information in `block_diag`

---
Review notes (text between the `---` line above and the first `diff --git`
line is not part of the commit message):

  BlockDiagonal.make_node previously built its output variable with
  `pytensor.tensor.matrix(dtype=dtype)`, without passing a static shape.
  It now passes `shape=out_shape`, where each output dimension is the sum
  of the inputs' static sizes along that axis, or None as soon as any
  input's static size along that axis is None.

  The tests assert the resulting `type.shape` for fully known input
  shapes, for one unknown column size, and for unknown sizes on both axes.

 pytensor/tensor/slinalg.py   | 13 ++++++++++++-
 tests/tensor/test_slinalg.py | 17 +++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 1ad427c0a9..946abbb0d6 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1651,7 +1651,18 @@ class BlockDiagonal(BaseBlockDiagonal):
     def make_node(self, *matrices):
         matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
         dtype = _largest_common_dtype(matrices)
-        out_type = pytensor.tensor.matrix(dtype=dtype)
+
+        shapes_by_dim = tuple(zip(*(m.type.shape for m in matrices)))
+        out_shape = tuple(
+            [
+                sum(dim_shapes)
+                if not any(shape is None for shape in dim_shapes)
+                else None
+                for dim_shapes in shapes_by_dim
+            ]
+        )
+
+        out_type = pytensor.tensor.matrix(shape=out_shape, dtype=dtype)
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index b7a5fbb510..8b48c33a3c 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -1040,11 +1040,28 @@ def test_block_diagonal():
     A = np.array([[1.0, 2.0], [3.0, 4.0]])
     B = np.array([[5.0, 6.0], [7.0, 8.0]])
     result = block_diag(A, B)
+    assert result.type.shape == (4, 4)
     assert result.owner.op.core_op._props_dict() == {"n_inputs": 2}
 
     np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
 
 
+def test_block_diagonal_static_shape():
+    A = pt.dmatrix("A", shape=(5, 5))
+    B = pt.dmatrix("B", shape=(3, 10))
+    result = block_diag(A, B)
+    assert result.type.shape == (8, 15)
+
+    A = pt.dmatrix("A", shape=(5, 5))
+    B = pt.dmatrix("B", shape=(3, None))
+    result = block_diag(A, B)
+    assert result.type.shape == (8, None)
+
+    A = pt.dmatrix("A", shape=(None, 5))
+    result = block_diag(A, B)
+    assert result.type.shape == (None, None)
+
+
 def test_block_diagonal_grad():
     A = np.array([[1.0, 2.0], [3.0, 4.0]])
     B = np.array([[5.0, 6.0], [7.0, 8.0]])