@@ -152,16 +152,56 @@ def test_qr_modes():
152152 assert "name 'complete' is not defined" in str (e )
153153
154154
155- @pytest .mark .parametrize ("shape" , [(3 , 3 ), (6 , 3 )], ids = ["shape=(3, 3)" , "shape=(6,3)" ])
156- @pytest .mark .parametrize ("output" , [0 , 1 ], ids = ["Q" , "R" ])
157- def test_qr_grad (shape , output ):
155+ @pytest .mark .parametrize (
156+ "shape, gradient_test_case, mode" ,
157+ (
158+ [(s , c , "reduced" ) for s in [(3 , 3 ), (6 , 3 ), (3 , 6 )] for c in [0 , 1 , 2 ]]
159+ + [(s , c , "complete" ) for s in [(3 , 3 ), (6 , 3 ), (3 , 6 )] for c in [0 , 1 , 2 ]]
160+ + [(s , 0 , "r" ) for s in [(3 , 3 ), (6 , 3 ), (3 , 6 )]]
161+ + [((3 , 3 ), 0 , "raw" )]
162+ ),
163+ ids = (
164+ [
165+ f"shape={ s } , gradient_test_case={ c } , mode=reduced"
166+ for s in [(3 , 3 ), (6 , 3 ), (3 , 6 )]
167+ for c in ["Q" , "R" , "both" ]
168+ ]
169+ + [
170+ f"shape={ s } , gradient_test_case={ c } , mode=complete"
171+ for s in [(3 , 3 ), (6 , 3 ), (3 , 6 )]
172+ for c in ["Q" , "R" , "both" ]
173+ ]
174+ + [f"shape={ s } , gradient_test_case=R, mode=r" for s in [(3 , 3 ), (6 , 3 ), (3 , 6 )]]
175+ + ["shape=(3, 3), gradient_test_case=Q, mode=raw" ]
176+ ),
177+ )
178+ def test_qr_grad (shape , gradient_test_case , mode ):
158179 rng = np .random .default_rng (utt .fetch_seed ())
159180
160- def _test_fn (x ):
161- return qr (x , mode = "reduced" )[output ]
181+ def _test_fn (x , case = 2 , mode = "reduced" ):
182+ if case == 0 :
183+ return qr (x , mode = mode )[0 ].sum ()
184+ elif case == 1 :
185+ return qr (x , mode = mode )[1 ].sum ()
186+ elif case == 2 :
187+ Q , R = qr (x , mode = mode )
188+ return Q .sum () + R .sum ()
162189
190+ m , n = shape
163191 a = rng .standard_normal (shape ).astype (config .floatX )
164- utt .verify_grad (_test_fn , [a ], rng = np .random )
192+
193+ if m < n or (mode == "complete" and m != n ) or mode == "raw" :
194+ with pytest .raises (NotImplementedError ):
195+ utt .verify_grad (
196+ partial (_test_fn , case = gradient_test_case , mode = mode ),
197+ [a ],
198+ rng = np .random ,
199+ )
200+
201+ else :
202+ utt .verify_grad (
203+ partial (_test_fn , case = gradient_test_case , mode = mode ), [a ], rng = np .random
204+ )
165205
166206
167207class TestSvd (utt .InferShapeTester ):
0 commit comments