@@ -1729,7 +1729,10 @@ def check_qr_stacked(self, a):
17291729 assert_ (q .shape [- 2 :] == (m , m ))
17301730 assert_ (r .shape [- 2 :] == (m , n ))
17311731 assert_almost_equal (matmul (q , r ), a )
1732- assert_almost_equal (swapaxes (q , - 1 , - 2 ).conj (), np .linalg .inv (q ))
1732+ I_mat = np .identity (q .shape [- 1 ])
1733+ stack_I_mat = np .broadcast_to (I_mat ,
1734+ q .shape [:- 2 ] + (q .shape [- 1 ],)* 2 )
1735+ assert_almost_equal (matmul (swapaxes (q , - 1 , - 2 ).conj (), q ), stack_I_mat )
17331736 assert_almost_equal (np .triu (r [..., :, :]), r )
17341737
17351738 # mode == 'reduced'
@@ -1741,7 +1744,10 @@ def check_qr_stacked(self, a):
17411744 assert_ (q1 .shape [- 2 :] == (m , k ))
17421745 assert_ (r1 .shape [- 2 :] == (k , n ))
17431746 assert_almost_equal (matmul (q1 , r1 ), a )
1744- assert_almost_equal (swapaxes (q , - 1 , - 2 ).conj (), np .linalg .inv (q ))
1747+ I_mat = np .identity (q1 .shape [- 1 ])
1748+ stack_I_mat = np .broadcast_to (I_mat ,
1749+ q1 .shape [:- 2 ] + (q1 .shape [- 1 ],)* 2 )
1750+ assert_almost_equal (matmul (swapaxes (q1 , - 1 , - 2 ).conj (), q1 ), stack_I_mat )
17451751 assert_almost_equal (np .triu (r1 [..., :, :]), r1 )
17461752
17471753 # mode == 'r'
@@ -1750,22 +1756,20 @@ def check_qr_stacked(self, a):
17501756 assert_ (isinstance (r2 , a_type ))
17511757 assert_almost_equal (r2 , r1 )
17521758
1753- def test_stacked_inputs (self ):
1754-
1755- normal = np .random .normal
1756- sizes = [(3 , 4 ), (4 , 3 ), (4 , 4 ), (3 , 0 ), (0 , 3 )]
1757- dts = [np .float32 , np .float64 , np .complex64 ]
1758- for size in sizes :
1759- for dt in dts :
1760- a1 , a2 , a3 , a4 = [normal (size = size ), normal (size = size ),
1761- normal (size = size ), normal (size = size )]
1762- b1 , b2 , b3 , b4 = [normal (size = size ), normal (size = size ),
1763- normal (size = size ), normal (size = size )]
1764- A = np .asarray ([[a1 , a2 ], [a3 , a4 ]], dtype = dt )
1765- B = np .asarray ([[b1 , b2 ], [b3 , b4 ]], dtype = dt )
1766- self .check_qr_stacked (A )
1767- self .check_qr_stacked (B )
1768- self .check_qr_stacked (A + 1.j * B )
1759+ @pytest .mark .parametrize ("size" , [
1760+ (3 , 4 ), (4 , 3 ), (4 , 4 ),
1761+ (3 , 0 ), (0 , 3 )])
1762+ @pytest .mark .parametrize ("outer_size" , [
1763+ (2 , 2 ), (2 ,), (2 , 3 , 4 )])
1764+ @pytest .mark .parametrize ("dt" , [
1765+ np .single , np .double ,
1766+ np .csingle , np .cdouble ])
1767+ def test_stacked_inputs (self , outer_size , size , dt ):
1768+
1769+ A = np .random .normal (size = outer_size + size ).astype (dt )
1770+ B = np .random .normal (size = outer_size + size ).astype (dt )
1771+ self .check_qr_stacked (A )
1772+ self .check_qr_stacked (A + 1.j * B )
17691773
17701774
17711775class TestCholesky :
0 commit comments