|
24 | 24 | "assert_shape",
|
25 | 25 | "assert_result_shape",
|
26 | 26 | "assert_keepdimable_shape",
|
| 27 | + "assert_0d_equals", |
27 | 28 | "assert_fill",
|
28 | 29 | "assert_array",
|
29 | 30 | ]
|
@@ -242,15 +243,28 @@ def assert_fill(
|
242 | 243 | def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
|
243 | 244 | assert_dtype(func_name, out.dtype, expected.dtype)
|
244 | 245 | assert_shape(func_name, out.shape, expected.shape, **kw)
|
245 |
| - msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}" |
| 246 | + f_func = f"[{func_name}({fmt_kw(kw)})]" |
246 | 247 | if dh.is_float_dtype(out.dtype):
|
247 |
| - neg_zeros = expected == -0.0 |
248 |
| - assert xp.all((out == -0.0) == neg_zeros), msg |
249 |
| - pos_zeros = expected == +0.0 |
250 |
| - assert xp.all((out == +0.0) == pos_zeros), msg |
251 |
| - nans = xp.isnan(expected) |
252 |
| - assert xp.all(xp.isnan(out) == nans), msg |
253 |
| - mask = ~(neg_zeros | pos_zeros | nans) |
254 |
| - assert xp.all(out[mask] == expected[mask]), msg |
| 248 | + for idx in sh.ndindex(out.shape): |
| 249 | + at_out = out[idx] |
| 250 | + at_expected = expected[idx] |
| 251 | + msg = ( |
| 252 | + f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} " |
| 253 | + f"{f_func}" |
| 254 | + ) |
| 255 | + if xp.isnan(at_expected): |
| 256 | + assert xp.isnan(at_out), msg |
| 257 | + elif at_expected == 0.0 or at_expected == -0.0: |
| 258 | + scalar_at_expected = float(at_expected) |
| 259 | + scalar_at_out = float(at_out) |
| 260 | + if is_pos_zero(scalar_at_expected): |
| 261 | + assert is_pos_zero(scalar_at_out), msg |
| 262 | + else: |
| 263 | + assert is_neg_zero(scalar_at_expected) # sanity check |
| 264 | + assert is_neg_zero(scalar_at_out), msg |
| 265 | + else: |
| 266 | + assert at_out == at_expected, msg |
255 | 267 | else:
|
256 |
| - assert xp.all(out == expected), msg |
| 268 | + assert xp.all(out == expected), ( |
| 269 | + f"out not as expected {f_func}\n" f"{out=}\n{expected=}" |
| 270 | + ) |
0 commit comments