29
29
cast ,
30
30
constant ,
31
31
get_underlying_scalar_constant_value ,
32
+ join ,
32
33
moveaxis ,
33
34
ones_like ,
34
35
register_infer_shape ,
36
+ split ,
35
37
switch ,
36
38
zeros_like ,
37
39
)
99
101
)
100
102
from pytensor .tensor .rewriting .elemwise import apply_local_dimshuffle_lift
101
103
from pytensor .tensor .shape import Shape , Shape_i
104
+ from pytensor .tensor .slinalg import BlockDiagonal
102
105
from pytensor .tensor .subtensor import Subtensor
103
106
from pytensor .tensor .type import (
104
107
complex_dtypes ,
@@ -167,6 +170,72 @@ def local_0_dot_x(fgraph, node):
167
170
return [constant_zero ]
168
171
169
172
173
+ @register_canonicalize
174
+ @register_specialize
175
+ @register_stabilize
176
+ @node_rewriter ([Dot ])
177
+ def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
178
+ r"""
179
+ Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))``
180
+
181
+ BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
182
+ of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
183
+ a single dot on the larger matrix.
184
+ """
185
+ x , y = node .inputs
186
+ op = node .op
187
+
188
+ def check_for_block_diag (x ):
189
+ return x .owner and (
190
+ isinstance (x .owner .op , BlockDiagonal )
191
+ or isinstance (x .owner .op , Blockwise )
192
+ and isinstance (x .owner .op .core_op , BlockDiagonal )
193
+ )
194
+
195
+ if not (check_for_block_diag (x ) or check_for_block_diag (y )):
196
+ return None
197
+
198
+ # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
199
+ # non-block diagonal, and return a new block diagonal
200
+ if check_for_block_diag (x ) and not check_for_block_diag (y ):
201
+ components = x .owner .inputs
202
+ y_splits = split (
203
+ y ,
204
+ splits_size = [component .shape [- 1 ] for component in components ],
205
+ n_splits = len (components ),
206
+ )
207
+ new_components = [
208
+ op (component , y_split ) for component , y_split in zip (components , y_splits )
209
+ ]
210
+ new_output = join (0 , * new_components )
211
+ elif not check_for_block_diag (x ) and check_for_block_diag (y ):
212
+ components = y .owner .inputs
213
+ new_components = [op (x , component ) for component in components ]
214
+ new_output = join (0 , * new_components )
215
+
216
+ # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. In
217
+ # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case
218
+ elif any (shape is None for shape in (* x .type .shape , * y .type .shape )):
219
+ return None
220
+ elif x .ndim == y .ndim and all (
221
+ x_shape == y_shape for x_shape , y_shape in zip (x .type .shape , y .type .shape )
222
+ ):
223
+ x_components = x .owner .inputs
224
+ y_components = y .owner .inputs
225
+
226
+ if len (x_components ) != len (y_components ):
227
+ return None
228
+
229
+ new_output = BlockDiagonal (len (x_components ))(
230
+ * [op (x_comp , y_comp ) for x_comp , y_comp in zip (x_components , y_components )]
231
+ )
232
+ else :
233
+ return None
234
+
235
+ copy_stack_trace (node .outputs [0 ], new_output )
236
+ return [new_output ]
237
+
238
+
170
239
@register_canonicalize
171
240
@node_rewriter ([DimShuffle ])
172
241
def local_lift_transpose_through_dot (fgraph , node ):
@@ -2496,7 +2565,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
2496
2565
name = "add_canonizer_group" ,
2497
2566
)
2498
2567
2499
-
2500
2568
register_canonicalize (local_add_canonizer , "shape_unsafe" , name = "local_add_canonizer" )
2501
2569
2502
2570
@@ -3619,7 +3687,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
3619
3687
)
3620
3688
register_stabilize (logdiffexp_to_log1mexpdiff , name = "logdiffexp_to_log1mexpdiff" )
3621
3689
3622
-
3623
3690
# log(sigmoid(x) / (1 - sigmoid(x))) -> x
3624
3691
# i.e logit(sigmoid(x)) -> x
3625
3692
local_logit_sigmoid = PatternNodeRewriter (
@@ -3633,7 +3700,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
3633
3700
register_canonicalize (local_logit_sigmoid )
3634
3701
register_specialize (local_logit_sigmoid )
3635
3702
3636
-
3637
3703
# sigmoid(log(x / (1-x)) -> x
3638
3704
# i.e., sigmoid(logit(x)) -> x
3639
3705
local_sigmoid_logit = PatternNodeRewriter (
@@ -3674,7 +3740,6 @@ def local_useless_conj(fgraph, node):
3674
3740
3675
3741
register_specialize (local_polygamma_to_tri_gamma )
3676
3742
3677
-
3678
3743
local_log_kv = PatternNodeRewriter (
3679
3744
# Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
3680
3745
# During stabilize -x is converted to -1.0 * x
0 commit comments