@@ -4749,15 +4749,21 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
 
 
 @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
-def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
+@pytest.mark.parametrize(
+    "batch_left", [True, False], ids=["batched_left", "unbatched_left"]
+)
+@pytest.mark.parametrize(
+    "batch_right", [True, False], ids=["batched_right", "unbatched_right"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right):
     """
     Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
     """
     a = tensor("a", shape=(4, 2))
-    b = tensor("b", shape=(2, 4))
+    b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4))
     c = tensor("c", shape=(4, 4))
     d = tensor("d", shape=(10, 10))
-    e = tensor("e", shape=(10, 10))
+    e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10))
 
     x = pt.linalg.block_diag(a, b, c)
 
@@ -4767,30 +4773,38 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     else:
         out = [d @ x, e @ x]
 
-    fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
+    with config.change_flags(optimizer_verbose=True):
+        fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
+
     assert not any(
         isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
     )
 
     fn_expected = pytensor.function(
         [a, b, c, d, e],
         out,
-        mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
+        mode=Mode(linker="py", optimizer=None),
     )
 
+    # TODO: Count Dots
+
     rng = np.random.default_rng()
     a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
     b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
     c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
     d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
     e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
 
-    np.testing.assert_allclose(
-        fn(a_val, b_val, c_val, d_val, e_val),
-        fn_expected(a_val, b_val, c_val, d_val, e_val),
-        atol=1e-6 if config.floatX == "float32" else 1e-12,
-        rtol=1e-6 if config.floatX == "float32" else 1e-12,
-    )
+    rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val)
+    expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val)
+
+    for out, expected in zip(rewrite_outs, expected_outs):
+        np.testing.assert_allclose(
+            out,
+            expected,
+            atol=1e-6 if config.floatX == "float32" else 1e-12,
+            rtol=1e-6 if config.floatX == "float32" else 1e-12,
+        )
 
 
 @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
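The rewrite under test replaces dot(block_diag(a, b, c), z) with a row-wise concatenation of per-block dots, as the docstring states. Below is a minimal NumPy sketch of that identity (not part of the commit), using the unbatched shapes from the test and scipy.linalg.block_diag as a stand-in for pt.linalg.block_diag:

# Editor's sketch: verify dot(block_diag(a, b, c), z) equals the concatenation
# of the per-block dots along the rows. Shapes mirror the unbatched case above.
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 2))    # block column widths are 2, 4, 4
b = rng.normal(size=(2, 4))
c = rng.normal(size=(4, 4))
z = rng.normal(size=(10, 10))  # block_diag(a, b, c) has shape (10, 10)

full = block_diag(a, b, c) @ z
split = np.concatenate([a @ z[:2], b @ z[2:6], c @ z[6:]], axis=0)
np.testing.assert_allclose(full, split)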