29
29
constant ,
30
30
expand_dims ,
31
31
get_underlying_scalar_constant_value ,
32
+ join ,
32
33
moveaxis ,
33
34
ones_like ,
34
35
register_infer_shape ,
36
+ split ,
35
37
switch ,
36
38
zeros ,
37
39
zeros_like ,
96
98
from pytensor .tensor .rewriting .elemwise import apply_local_dimshuffle_lift
97
99
from pytensor .tensor .rewriting .linalg import is_matrix_transpose
98
100
from pytensor .tensor .shape import Shape , Shape_i
101
+ from pytensor .tensor .slinalg import BlockDiagonal
99
102
from pytensor .tensor .subtensor import Subtensor
100
103
from pytensor .tensor .type import (
101
104
complex_dtypes ,
@@ -146,6 +149,72 @@ def local_0_dot_x(fgraph, node):
146
149
return [zeros ((x .shape [0 ], y .shape [1 ]), dtype = node .outputs [0 ].type .dtype )]
147
150
148
151
152
+ @register_canonicalize
153
+ @register_specialize
154
+ @register_stabilize
155
+ @node_rewriter ([Dot ])
156
+ def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
157
+ r"""
158
+ Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
159
+
160
+ BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
161
+ of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
162
+ a single dot on the larger matrix.
163
+ """
164
+ x , y = node .inputs
165
+ op = node .op
166
+
167
+ def check_for_block_diag (x ):
168
+ return x .owner and (
169
+ isinstance (x .owner .op , BlockDiagonal )
170
+ or isinstance (x .owner .op , Blockwise )
171
+ and isinstance (x .owner .op .core_op , BlockDiagonal )
172
+ )
173
+
174
+ if not (check_for_block_diag (x ) or check_for_block_diag (y )):
175
+ return None
176
+
177
+ # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
178
+ # non-block diagonal, and return a new block diagonal
179
+ if check_for_block_diag (x ) and not check_for_block_diag (y ):
180
+ components = x .owner .inputs
181
+ y_splits = split (
182
+ y ,
183
+ splits_size = [component .shape [- 1 ] for component in components ],
184
+ n_splits = len (components ),
185
+ )
186
+ new_components = [
187
+ op (component , y_split ) for component , y_split in zip (components , y_splits )
188
+ ]
189
+ new_output = join (0 , * new_components )
190
+ elif not check_for_block_diag (x ) and check_for_block_diag (y ):
191
+ components = y .owner .inputs
192
+ new_components = [op (x , component ) for component in components ]
193
+ new_output = join (0 , * new_components )
194
+
195
+ # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
196
+ # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
197
+ elif any (shape is None for shape in (* x .type .shape , * y .type .shape )):
198
+ return None
199
+ elif x .ndim == y .ndim and all (
200
+ x_shape == y_shape for x_shape , y_shape in zip (x .type .shape , y .type .shape )
201
+ ):
202
+ x_components = x .owner .inputs
203
+ y_components = y .owner .inputs
204
+
205
+ if len (x_components ) != len (y_components ):
206
+ return None
207
+
208
+ new_output = BlockDiagonal (len (x_components ))(
209
+ * [op (x_comp , y_comp ) for x_comp , y_comp in zip (x_components , y_components )]
210
+ )
211
+ else :
212
+ return None
213
+
214
+ copy_stack_trace (node .outputs [0 ], new_output )
215
+ return [new_output ]
216
+
217
+
149
218
@register_canonicalize
150
219
@node_rewriter ([Dot , _matmul ])
151
220
def local_lift_transpose_through_dot (fgraph , node ):
@@ -2582,7 +2651,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
2582
2651
name = "add_canonizer_group" ,
2583
2652
)
2584
2653
2585
-
2586
2654
register_canonicalize (local_add_canonizer , "shape_unsafe" , name = "local_add_canonizer" )
2587
2655
2588
2656
@@ -3720,7 +3788,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
3720
3788
)
3721
3789
register_stabilize (logdiffexp_to_log1mexpdiff , name = "logdiffexp_to_log1mexpdiff" )
3722
3790
3723
-
3724
3791
# log(sigmoid(x) / (1 - sigmoid(x))) -> x
3725
3792
# i.e logit(sigmoid(x)) -> x
3726
3793
local_logit_sigmoid = PatternNodeRewriter (
@@ -3734,7 +3801,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
3734
3801
register_canonicalize (local_logit_sigmoid )
3735
3802
register_specialize (local_logit_sigmoid )
3736
3803
3737
-
3738
3804
# sigmoid(log(x / (1-x)) -> x
3739
3805
# i.e., sigmoid(logit(x)) -> x
3740
3806
local_sigmoid_logit = PatternNodeRewriter (
@@ -3775,7 +3841,6 @@ def local_useless_conj(fgraph, node):
3775
3841
3776
3842
register_specialize (local_polygamma_to_tri_gamma )
3777
3843
3778
-
3779
3844
local_log_kv = PatternNodeRewriter (
3780
3845
# Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
3781
3846
# During stabilize -x is converted to -1.0 * x
0 commit comments