@@ -4750,41 +4750,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
47504750
47514751@pytest .mark .parametrize ("left_multiply" , [True , False ], ids = ["left" , "right" ])
47524752@pytest .mark .parametrize (
4753- "batch_left " , [True , False ], ids = ["batched_left " , "unbatched_left " ]
4753+ "batch_blockdiag " , [True , False ], ids = ["batch_blockdiag " , "unbatched_blockdiag " ]
47544754)
47554755@pytest .mark .parametrize (
4756- "batch_right " , [True , False ], ids = ["batched_right " , "unbatched_right " ]
4756+ "batch_other " , [True , False ], ids = ["batched_other " , "unbatched_other " ]
47574757)
4758- def test_local_block_diag_dot_to_dot_block_diag (left_multiply , batch_left , batch_right ):
4758+ def test_local_block_diag_dot_to_dot_block_diag (
4759+ left_multiply , batch_blockdiag , batch_other
4760+ ):
47594761 """
47604762 Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
47614763 """
4764+
4765+ def has_blockdiag (graph ):
4766+ return any (
4767+ (
4768+ var .owner
4769+ and (
4770+ isinstance (var .owner .op , BlockDiagonal )
4771+ or (
4772+ isinstance (var .owner .op , Blockwise )
4773+ and isinstance (var .owner .op .core_op , BlockDiagonal )
4774+ )
4775+ )
4776+ )
4777+ for var in ancestors ([graph ])
4778+ )
4779+
47624780 a = tensor ("a" , shape = (4 , 2 ))
4763- b = tensor ("b" , shape = (2 , 4 ) if not batch_left else (3 , 2 , 4 ))
4781+ b = tensor ("b" , shape = (2 , 4 ) if not batch_blockdiag else (3 , 2 , 4 ))
47644782 c = tensor ("c" , shape = (4 , 4 ))
4765- d = tensor ("d" , shape = (10 , 10 ))
4766- e = tensor ("e" , shape = (10 , 10 ) if not batch_right else (3 , 1 , 10 , 10 ))
4767-
47684783 x = pt .linalg .block_diag (a , b , c )
47694784
4785+ d = tensor ("d" , shape = (10 , 10 ) if not batch_other else (3 , 1 , 10 , 10 ))
4786+
47704787 # Test multiple clients are all rewritten
47714788 if left_multiply :
4772- out = [ x @ d , x @ e ]
4789+ out = x @ d
47734790 else :
4774- out = [ d @ x , e @ x ]
4791+ out = d @ x
47754792
4776- with config .change_flags (optimizer_verbose = True ):
4777- fn = pytensor .function ([a , b , c , d , e ], out , mode = rewrite_mode )
4778-
4779- assert not any (
4780- isinstance (node .op , BlockDiagonal ) for node in fn .maker .fgraph .toposort ()
4781- )
4793+ assert has_blockdiag (out )
4794+ fn = pytensor .function ([a , b , c , d ], out , mode = rewrite_mode )
4795+ assert not has_blockdiag (fn .maker .fgraph .outputs [0 ])
47824796
47834797 fn_expected = pytensor .function (
4784- [a , b , c , d , e ],
4798+ [a , b , c , d ],
47854799 out ,
47864800 mode = Mode (linker = "py" , optimizer = None ),
47874801 )
4802+ assert has_blockdiag (fn_expected .maker .fgraph .outputs [0 ])
47884803
47894804 # TODO: Count Dots
47904805
@@ -4793,18 +4808,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch
47934808 b_val = rng .normal (size = b .type .shape ).astype (b .type .dtype )
47944809 c_val = rng .normal (size = c .type .shape ).astype (c .type .dtype )
47954810 d_val = rng .normal (size = d .type .shape ).astype (d .type .dtype )
4796- e_val = rng .normal (size = e .type .shape ).astype (e .type .dtype )
47974811
4798- rewrite_outs = fn (a_val , b_val , c_val , d_val , e_val )
4799- expected_outs = fn_expected (a_val , b_val , c_val , d_val , e_val )
4800-
4801- for out , expected in zip (rewrite_outs , expected_outs ):
4802- np .testing .assert_allclose (
4803- out ,
4804- expected ,
4805- atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4806- rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4807- )
4812+ rewrite_out = fn (a_val , b_val , c_val , d_val )
4813+ expected_out = fn_expected (a_val , b_val , c_val , d_val )
4814+ np .testing .assert_allclose (
4815+ rewrite_out ,
4816+ expected_out ,
4817+ atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4818+ rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4819+ )
48084820
48094821
48104822@pytest .mark .parametrize ("rewrite" , [True , False ], ids = ["rewrite" , "no_rewrite" ])
0 commit comments