Skip to content

Commit 4c15018

Browse files
Make array tests verify shape first
1 parent b14c117 commit 4c15018

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

python/tests/test_array_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,12 @@ def save_expr(name, expr):
374374
def test_run_lda(fn_thunk, benchmark):
375375
fn = fn_thunk()
376376
# warmup once for numba
377-
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03)
377+
real_res = run_lda(X_np, y_np)
378+
379+
fn_res = fn(X_np, y_np)
380+
assert real_res.shape == fn_res.shape
381+
assert real_res.dtype == fn_res.dtype
382+
assert np.allclose(real_res, fn_res, rtol=1e-03)
378383
benchmark(fn, X_np, y_np)
379384

380385

0 commit comments

Comments
 (0)