|
7 | 7 | TODO: Automatic methods for determining best sparse format?
|
8 | 8 |
|
9 | 9 | """
|
| 10 | +from typing import Literal |
10 | 11 | from warnings import warn
|
11 | 12 |
|
12 | 13 | import numpy as np
|
|
47 | 48 | trunc,
|
48 | 49 | )
|
49 | 50 | from pytensor.tensor.shape import shape, specify_broadcastable
|
| 51 | +from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype |
50 | 52 | from pytensor.tensor.type import TensorType
|
51 | 53 | from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
|
52 | 54 | from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
|
|
60 | 62 |
|
61 | 63 | sparse_formats = ["csc", "csr"]
|
62 | 64 |
|
63 |
| - |
64 | 65 | """
|
65 | 66 | Types of sparse matrices to use for testing.
|
66 | 67 |
|
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
|
183 | 184 |
|
184 | 185 | as_sparse = as_sparse_variable
|
185 | 186 |
|
186 |
| - |
187 | 187 | as_sparse_or_tensor_variable = as_symbolic
|
188 | 188 |
|
189 | 189 |
|
@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
|
1800 | 1800 | return r
|
1801 | 1801 |
|
1802 | 1802 | def __str__(self):
|
1803 |
| - return f"{self.__class__.__name__ }{{axis={self.axis}}}" |
| 1803 | + return f"{self.__class__.__name__}{{axis={self.axis}}}" |
1804 | 1804 |
|
1805 | 1805 |
|
1806 | 1806 | def sp_sum(x, axis=None, sparse_grad=False):
|
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):
|
2775 | 2775 |
|
2776 | 2776 | greater_equal_s_d = GreaterEqualSD()
|
2777 | 2777 |
|
2778 |
| - |
2779 | 2778 | eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
|
2780 | 2779 |
|
2781 |
| - |
2782 | 2780 | neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
|
2783 | 2781 |
|
2784 |
| - |
2785 | 2782 | lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
|
2786 | 2783 |
|
2787 |
| - |
2788 | 2784 | gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
|
2789 | 2785 |
|
2790 |
| - |
2791 | 2786 | le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
|
2792 | 2787 |
|
2793 | 2788 | ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
|
@@ -2992,7 +2987,7 @@ def __str__(self):
|
2992 | 2987 | l = []
|
2993 | 2988 | if self.inplace:
|
2994 | 2989 | l.append("inplace")
|
2995 |
| - return f"{self.__class__.__name__ }{{{', '.join(l)}}}" |
| 2990 | + return f"{self.__class__.__name__}{{{', '.join(l)}}}" |
2996 | 2991 |
|
2997 | 2992 | def make_node(self, x):
|
2998 | 2993 | """
|
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
|
3291 | 3286 | # Simplify code by splitting into DotSS and DotSD.
|
3292 | 3287 |
|
3293 | 3288 | __props__ = ()
|
| 3289 | + |
3294 | 3290 | # The grad_preserves_dense attribute doesn't change the
|
3295 | 3291 | # execution behavior. To let the optimizer merge nodes with
|
3296 | 3292 | # different values of this attribute we shouldn't compare it
|
@@ -4260,3 +4256,85 @@ def grad(self, inputs, grads):
|
4260 | 4256 |
|
4261 | 4257 |
|
4262 | 4258 | construct_sparse_from_list = ConstructSparseFromList()
|
| 4259 | + |
| 4260 | + |
| 4261 | +class SparseBlockDiagonal(BaseBlockDiagonal): |
| 4262 | + __props__ = ( |
| 4263 | + "n_inputs", |
| 4264 | + "format", |
| 4265 | + ) |
| 4266 | + |
| 4267 | + def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"): |
| 4268 | + super().__init__(n_inputs) |
| 4269 | + self.format = format |
| 4270 | + |
| 4271 | + def make_node(self, *matrices): |
| 4272 | + matrices = self._validate_and_prepare_inputs( |
| 4273 | + matrices, as_sparse_or_tensor_variable |
| 4274 | + ) |
| 4275 | + dtype = _largest_common_dtype(matrices) |
| 4276 | + out_type = matrix(format=self.format, dtype=dtype) |
| 4277 | + |
| 4278 | + return Apply(self, matrices, [out_type]) |
| 4279 | + |
| 4280 | + def perform(self, node, inputs, output_storage, params=None): |
| 4281 | + dtype = node.outputs[0].type.dtype |
| 4282 | + output_storage[0][0] = scipy.sparse.block_diag( |
| 4283 | + inputs, format=self.format |
| 4284 | + ).astype(dtype) |
| 4285 | + |
| 4286 | + |
| 4287 | +def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"): |
| 4288 | + r""" |
| 4289 | + Construct a block diagonal matrix from a sequence of input matrices. |
| 4290 | +
|
| 4291 | + Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal: |
| 4292 | +
|
| 4293 | + [[A, 0, 0], |
| 4294 | + [0, B, 0], |
| 4295 | + [0, 0, C]] |
| 4296 | +
|
| 4297 | + Parameters |
| 4298 | + ---------- |
| 4299 | + A, B, C ... : tensors |
| 4300 | + Input tensors to form the block diagonal matrix. last two dimensions of the inputs will be used, and all |
| 4301 | + inputs should have at least 2 dimensins. |
| 4302 | +
|
| 4303 | + Note that the input matrices need not be sparse themselves, and will be automatically converted to the |
| 4304 | + requested format if they are not. |
| 4305 | +
|
| 4306 | + format: str, optional |
| 4307 | + The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False. |
| 4308 | +
|
| 4309 | + Returns |
| 4310 | + ------- |
| 4311 | + out: sparse matrix tensor |
| 4312 | + Symbolic sparse matrix in the specified format. |
| 4313 | +
|
| 4314 | + Examples |
| 4315 | + -------- |
| 4316 | + Create a sparse block diagonal matrix from two sparse 2x2 matrices: |
| 4317 | +
|
| 4318 | + ..code-block:: python |
| 4319 | + import numpy as np |
| 4320 | + from pytensor.sparse import block_diag |
| 4321 | + from scipy.sparse import csr_matrix |
| 4322 | +
|
| 4323 | + A = csr_matrix([[1, 2], [3, 4]]) |
| 4324 | + B = csr_matrix([[5, 6], [7, 8]]) |
| 4325 | + result_sparse = block_diag(A, B, format='csr', name='X') |
| 4326 | +
|
| 4327 | + print(result_sparse) |
| 4328 | + >>> SparseVariable{csr,int32} |
| 4329 | +
|
| 4330 | + print(result_sparse.toarray().eval()) |
| 4331 | + >>> array([[1, 2, 0, 0], |
| 4332 | + >>> [3, 4, 0, 0], |
| 4333 | + >>> [0, 0, 5, 6], |
| 4334 | + >>> [0, 0, 7, 8]]) |
| 4335 | + """ |
| 4336 | + if len(matrices) == 1: |
| 4337 | + return matrices |
| 4338 | + |
| 4339 | + _sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format) |
| 4340 | + return _sparse_block_diagonal(*matrices) |
0 commit comments