Skip to content

Commit a6ae03b

Browse files
author
Etienne Duchesne
committed
qr decomposition gradient: add xfail pytest for complex inputs
1 parent ac48c11 commit a6ae03b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/tensor/test_nlinalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ def test_qr_modes():
175175
+ ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
176176
),
177177
)
178-
def test_qr_grad(shape, gradient_test_case, mode):
178+
@pytest.mark.parametrize("is_complex", [True, False], ["complex", "real"])
179+
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
179180
rng = np.random.default_rng(utt.fetch_seed())
180181

181182
def _test_fn(x, case=2, mode="reduced"):
@@ -187,8 +188,13 @@ def _test_fn(x, case=2, mode="reduced"):
187188
Q, R = qr(x, mode=mode)
188189
return Q.sum() + R.sum()
189190

191+
if is_complex:
192+
pytest.xfail("Complex inputs currently not supported by verify_grad")
193+
190194
m, n = shape
191195
a = rng.standard_normal(shape).astype(config.floatX)
196+
if is_complex:
197+
a += 1j * rng.standard_normal(shape).astype(config.floatX)
192198

193199
if mode == "raw":
194200
with pytest.raises(NotImplementedError):

0 commit comments

Comments
 (0)