Skip to content

Commit ef3287c

Browse files
authored
Merge branch 'master' into test_pow_with_scalars
2 parents e4d2337 + 188b1e9 commit ef3287c

File tree

3 files changed

+188
-0
lines changed

3 files changed

+188
-0
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,63 @@ def test_take(x, data):
6060
# sanity check
6161
with pytest.raises(StopIteration):
6262
next(out_indices)
63+
64+
65+
66+
@pytest.mark.unvectorized
67+
@pytest.mark.min_version("2024.12")
68+
@given(
69+
x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)),
70+
data=st.data(),
71+
)
72+
def test_take_along_axis(x, data):
73+
# TODO
74+
# 2. negative indices
75+
# 3. different dtypes for indices
76+
# 4. "broadcast-compatible" indices
77+
axis = data.draw(
78+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
79+
label="axis"
80+
)
81+
if axis is None:
82+
axis_kw = {}
83+
n_axis = x.ndim - 1
84+
else:
85+
axis_kw = {"axis": axis}
86+
n_axis = axis + x.ndim if axis < 0 else axis
87+
88+
new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len")
89+
idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:]
90+
indices = data.draw(
91+
hh.arrays(
92+
shape=idx_shape,
93+
dtype=dh.default_int,
94+
elements={"min_value": 0, "max_value": x.shape[n_axis]-1}
95+
),
96+
label="indices"
97+
)
98+
note(f"{indices=} {idx_shape=}")
99+
100+
out = xp.take_along_axis(x, indices, **axis_kw)
101+
102+
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
103+
ph.assert_shape(
104+
"take_along_axis",
105+
out_shape=out.shape,
106+
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
107+
kw=dict(
108+
x=x,
109+
indices=indices,
110+
axis=axis,
111+
),
112+
)
113+
114+
# value test: notation is from `np.take_along_axis` docstring
115+
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
116+
for ii in sh.ndindex(Ni):
117+
for kk in sh.ndindex(Nk):
118+
a_1d = x[ii + (slice(None),) + kk]
119+
i_1d = indices[ii + (slice(None),) + kk]
120+
o_1d = out[ii + (slice(None),) + kk]
121+
for j in range(new_len):
122+
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,26 @@ def test_not_equal(ctx, data):
15971597
)
15981598

15991599

1600+
@pytest.mark.min_version("2024.12")
1601+
@given(
1602+
shapes=hh.two_mutually_broadcastable_shapes,
1603+
dtype=hh.real_floating_dtypes,
1604+
data=st.data()
1605+
)
1606+
def test_nextafter(shapes, dtype, data):
1607+
x1 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x1")
1608+
x2 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x2")
1609+
1610+
out = xp.nextafter(x1, x2)
1611+
_assert_correctness_binary(
1612+
"nextafter",
1613+
math.nextafter,
1614+
in_dtypes=[x1.dtype, x2.dtype],
1615+
in_shapes=[x1.shape, x2.shape],
1616+
in_arrs=[x1, x2],
1617+
out=out
1618+
)
1619+
16001620
@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
16011621
@given(data=st.data())
16021622
def test_positive(ctx, data):
@@ -1815,6 +1835,7 @@ def _filter_zero(x):
18151835
("divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
18161836
("hypot", math.hypot, {}, None),
18171837
("logaddexp", logaddexp_refimpl, {}, None),
1838+
("nextafter", math.nextafter, {}, None),
18181839
("maximum", max, {'strict_check': True}, None),
18191840
("minimum", min, {'strict_check': True}, None),
18201841
("multiply", operator.mul, {}, None),
@@ -1900,3 +1921,34 @@ def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
19001921
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
19011922
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
19021923

1924+
1925+
@pytest.mark.unvectorized
1926+
@given(
1927+
x1x2=hh.array_and_py_scalar([xp.int32]),
1928+
data=st.data()
1929+
)
1930+
def test_where_with_scalars(x1x2, data):
1931+
x1, x2 = x1x2
1932+
1933+
if dh.is_scalar(x1):
1934+
dtype, shape = x2.dtype, x2.shape
1935+
x1_arr, x2_arr = xp.broadcast_to(xp.asarray(x1), shape), x2
1936+
else:
1937+
dtype, shape = x1.dtype, x1.shape
1938+
x1_arr, x2_arr = x1, xp.broadcast_to(xp.asarray(x2), shape)
1939+
1940+
condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool))
1941+
1942+
out = xp.where(condition, x1, x2)
1943+
1944+
assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}"
1945+
assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}"
1946+
1947+
# value test
1948+
for idx in sh.ndindex(shape):
1949+
if condition[idx]:
1950+
assert out[idx] == x1_arr[idx]
1951+
else:
1952+
assert out[idx] == x2_arr[idx]
1953+
1954+

array_api_tests/test_utility_functions.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,79 @@ def test_any(x, data):
6363
expected = any(elements)
6464
ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
6565
out=result, expected=expected, kw=kw)
66+
67+
68+
@pytest.mark.unvectorized
69+
@pytest.mark.min_version("2024.12")
70+
@given(
71+
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
72+
data=st.data(),
73+
)
74+
def test_diff(x, data):
75+
axis = data.draw(
76+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
77+
label="axis"
78+
)
79+
if axis is None:
80+
axis_kw = {"axis": -1}
81+
n_axis = x.ndim - 1
82+
else:
83+
axis_kw = {"axis": axis}
84+
n_axis = axis + x.ndim if axis < 0 else axis
85+
86+
n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
87+
88+
out = xp.diff(x, **axis_kw, n=n)
89+
90+
expected_shape = list(x.shape)
91+
expected_shape[n_axis] -= n
92+
93+
assert out.shape == tuple(expected_shape)
94+
95+
# value test
96+
if n == 1:
97+
for idx in sh.ndindex(out.shape):
98+
l = list(idx)
99+
l[n_axis] += 1
100+
assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
101+
102+
103+
@pytest.mark.min_version("2024.12")
104+
@pytest.mark.unvectorized
105+
@given(
106+
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
107+
data=st.data(),
108+
)
109+
def test_diff_append_prepend(x, data):
110+
axis = data.draw(
111+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
112+
label="axis"
113+
)
114+
if axis is None:
115+
axis_kw = {"axis": -1}
116+
n_axis = x.ndim - 1
117+
else:
118+
axis_kw = {"axis": axis}
119+
n_axis = axis + x.ndim if axis < 0 else axis
120+
121+
n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
122+
123+
append_shape = list(x.shape)
124+
append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis")
125+
append_shape[n_axis] = append_axis_len
126+
append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append")
127+
128+
prepend_shape = list(x.shape)
129+
prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis")
130+
prepend_shape[n_axis] = prepend_axis_len
131+
prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend")
132+
133+
out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
134+
135+
in_1 = xp.concat((prepend, x, append), **axis_kw)
136+
out_1 = xp.diff(in_1, **axis_kw, n=n)
137+
138+
assert out.shape == out_1.shape
139+
for idx in sh.ndindex(out.shape):
140+
assert out[idx] == out_1[idx], f"{idx = }"
141+

0 commit comments

Comments
 (0)