Skip to content

Commit 8cf4a84

Browse files
committed
TST: Test copy behavior in astype
1 parent f7a74a6 commit 8cf4a84

File tree

3 files changed

+51
-30
lines changed

3 files changed

+51
-30
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
from inspect import getfullargspec
44
from typing import Any, Dict, Optional, Sequence, Tuple, Union
55

6-
from . import _array_module as xp
6+
from hypothesis import note
7+
8+
from . import _array_module as xp, xps
79
from . import dtype_helpers as dh
810
from . import shape_helpers as sh
11+
from . import hypothesis_helpers as hh
912
from . import stubs
1013
from . import xp as _xp
1114
from .typing import Array, DataType, Scalar, ScalarType, Shape
@@ -28,6 +31,7 @@
2831
"assert_0d_equals",
2932
"assert_fill",
3033
"assert_array_elements",
34+
"assert_kw_copy"
3135
]
3236

3337

@@ -483,6 +487,48 @@ def assert_fill(
483487
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
484488

485489

490+
def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
491+
if cmath.isnan(s1):
492+
return cmath.isnan(s2)
493+
else:
494+
return s1 == s2
495+
496+
497+
def assert_kw_copy(func_name, x, out, data, copy):
498+
"""
499+
Assert copy=True/False functionality is respected
500+
501+
TODO: we're not able to check scalars with this approach
502+
"""
503+
if copy is not None and len(x.shape) > 0:
504+
stype = dh.get_scalar_type(x.dtype)
505+
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
506+
old_value = stype(x[idx])
507+
scalar_strat = hh.from_dtype(x.dtype).filter(
508+
lambda n: not scalar_eq(n, old_value)
509+
)
510+
value = data.draw(
511+
scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)),
512+
label="mutating value",
513+
)
514+
x[idx] = value
515+
note(f"mutated {x=}")
516+
# sanity check
517+
assert_scalar_equals(
518+
"__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x"
519+
)
520+
new_out_value = stype(out[idx])
521+
f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}"
522+
if copy:
523+
assert scalar_eq(
524+
new_out_value, old_value
525+
), f"{f_out}, but should be {old_value} even after x was mutated"
526+
else:
527+
assert scalar_eq(
528+
new_out_value, value
529+
), f"{f_out}, but should be {value} after x was mutated"
530+
531+
486532
def _has_functional_signbit() -> bool:
487533
# signbit can be available but not implemented (e.g., in array-api-strict)
488534
if not hasattr(_xp, "signbit"):

array_api_tests/test_creation_functions.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -282,34 +282,7 @@ def test_asarray_arrays(shape, dtypes, data):
282282
ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype)
283283
ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape)
284284
ph.assert_array_elements("asarray", out=out, expected=x, kw=kw)
285-
copy = kw.get("copy", None)
286-
if copy is not None:
287-
stype = dh.get_scalar_type(x.dtype)
288-
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
289-
old_value = stype(x[idx])
290-
scalar_strat = hh.from_dtype(dtypes.input_dtype).filter(
291-
lambda n: not scalar_eq(n, old_value)
292-
)
293-
value = data.draw(
294-
scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)),
295-
label="mutating value",
296-
)
297-
x[idx] = value
298-
note(f"mutated {x=}")
299-
# sanity check
300-
ph.assert_scalar_equals(
301-
"__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x"
302-
)
303-
new_out_value = stype(out[idx])
304-
f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}"
305-
if copy:
306-
assert scalar_eq(
307-
new_out_value, old_value
308-
), f"{f_out}, but should be {old_value} even after x was mutated"
309-
else:
310-
assert scalar_eq(
311-
new_out_value, value
312-
), f"{f_out}, but should be {value} after x was mutated"
285+
ph.assert_kw_copy("asarray", x, out, data, kw.get("copy", None))
313286

314287

315288
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes))

array_api_tests/test_data_type_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def test_astype(x_dtype, dtype, kw, data):
8181
ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)
8282
ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw)
8383
# TODO: test values
84-
# TODO: test copy
84+
# Check copy is respected (only if input dtype is same as output dtype)
85+
if dtype == x_dtype:
86+
ph.assert_kw_copy("astype", x, out, data, kw.get("copy", None))
8587

8688

8789
@given(

0 commit comments

Comments
 (0)