Skip to content

Commit 383f7f1

Browse files
authored
Merge branch 'master' into reciprocal
2 parents d649c0c + 5d890cd commit 383f7f1

File tree

3 files changed

+37
-47
lines changed

3 files changed

+37
-47
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,17 +370,23 @@ def mutually_broadcastable_shapes(
370370
# TODO: Add support for complex Hermitian matrices
371371
@composite
372372
def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.):
373+
# for now, only generate elements from (1, bound); TODO: restore
374+
# generating from (-bound, -1/bound).or.(1/bound, bound)
375+
# Note that using `assume` triggers a HealthCheck for filtering too much.
373376
shape = draw(square_matrix_shapes)
374377
dtype = draw(dtypes)
375378
if not isinstance(finite, bool):
376379
finite = draw(finite)
377-
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
380+
if finite:
381+
elements = {'allow_nan': False, 'allow_infinity': False,
382+
'min_value': 1, 'max_value': bound}
383+
else:
384+
elements = None
378385
a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
379386
at = ah._matrix_transpose(a)
380387
H = (a + at)*0.5
381388
if finite:
382389
assume(not xp.any(xp.isinf(H)))
383-
assume(xp.all((H == 0.) | ((1/bound <= xp.abs(H)) & (xp.abs(H) <= bound))))
384390
return H
385391

386392
@composite

array_api_tests/test_data_type_functions.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -129,31 +129,17 @@ def test_broadcast_to(x, data):
129129
# TODO: test values
130130

131131

132-
@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data())
133-
def test_can_cast(_from, to, data):
134-
from_ = data.draw(
135-
st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_"
136-
)
132+
@given(_from=hh.all_dtypes, to=hh.all_dtypes)
133+
def test_can_cast(_from, to):
134+
out = xp.can_cast(_from, to)
137135

138-
out = xp.can_cast(from_, to)
136+
expected = False
137+
for other in dh.all_dtypes:
138+
if dh.promotion_table.get((_from, other)) == to:
139+
expected = True
140+
break
139141

140142
f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
141-
assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}"
142-
if _from == xp.bool:
143-
expected = to == xp.bool
144-
else:
145-
same_family = None
146-
for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]:
147-
if _from in dtypes:
148-
same_family = to in dtypes
149-
break
150-
assert same_family is not None # sanity check
151-
if same_family:
152-
from_min, from_max = dh.dtype_ranges[_from]
153-
to_min, to_max = dh.dtype_ranges[to]
154-
expected = from_min >= to_min and from_max <= to_max
155-
else:
156-
expected = False
157143
if expected:
158144
# cross-kind casting is not explicitly disallowed. We can only test
159145
# the cases where it should return True. TODO: if expected=False,

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,14 +1061,13 @@ def refimpl(_x, _min, _max):
10611061
)
10621062

10631063

1064-
if api_version >= "2022.12":
1065-
1066-
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1067-
def test_conj(x):
1068-
out = xp.conj(x)
1069-
ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype)
1070-
ph.assert_shape("conj", out_shape=out.shape, expected=x.shape)
1071-
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
1064+
@pytest.mark.min_version("2022.12")
1065+
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1066+
def test_conj(x):
1067+
out = xp.conj(x)
1068+
ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype)
1069+
ph.assert_shape("conj", out_shape=out.shape, expected=x.shape)
1070+
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
10721071

10731072

10741073
@pytest.mark.min_version("2023.12")
@@ -1263,14 +1262,14 @@ def test_hypot(x1, x2):
12631262
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
12641263

12651264

1266-
if api_version >= "2022.12":
12671265

1268-
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1269-
def test_imag(x):
1270-
out = xp.imag(x)
1271-
ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1272-
ph.assert_shape("imag", out_shape=out.shape, expected=x.shape)
1273-
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
1266+
@pytest.mark.min_version("2022.12")
1267+
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1268+
def test_imag(x):
1269+
out = xp.imag(x)
1270+
ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1271+
ph.assert_shape("imag", out_shape=out.shape, expected=x.shape)
1272+
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
12741273

12751274

12761275
@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
@@ -1575,14 +1574,13 @@ def test_reciprocal(x):
15751574
)
15761575

15771576

1578-
if api_version >= "2022.12":
1579-
1580-
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1581-
def test_real(x):
1582-
out = xp.real(x)
1583-
ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1584-
ph.assert_shape("real", out_shape=out.shape, expected=x.shape)
1585-
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
1577+
@pytest.mark.min_version("2022.12")
1578+
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
1579+
def test_real(x):
1580+
out = xp.real(x)
1581+
ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
1582+
ph.assert_shape("real", out_shape=out.shape, expected=x.shape)
1583+
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
15861584

15871585

15881586
@pytest.mark.skip(reason="flaky")

0 commit comments

Comments
 (0)