Skip to content

Commit 6c1cf07

Browse files
committed
Handle case of complex fill val
1 parent 5c14ce6 commit 6c1cf07

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -762,14 +762,14 @@ def perform_replacements( # numpydoc ignore=PR01,RT01
762762
x = xp.where(idx_posinf, finfo.max, x)
763763
return xp.where(idx_neginf, finfo.min, x)
764764

765-
if xp.isdtype(x.dtype, "complex floating"):
765+
if isinstance(fill_value, complex) or xp.isdtype(x.dtype, "complex floating"):
766766
return perform_replacements(
767767
xp.real(x),
768-
fill_value,
768+
fill_value.real,
769769
xp,
770770
) + 1j * perform_replacements(
771771
xp.imag(x),
772-
fill_value,
772+
fill_value.imag,
773773
xp,
774774
)
775775

tests/test_funcs.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -988,11 +988,25 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None:
988988
xp.asarray([infinity + 0j, 0 + 0j, 0 + 1j * infinity]),
989989
)
990990

991-
def test_fill_value(self, xp: ModuleType) -> None:
991+
@pytest.mark.parametrize("fill_value", [3, 3.0, 3 + 0j])
992+
@pytest.mark.parametrize(
993+
"output",
994+
[
995+
[1.0, 2.0, 3.0, 4.0],
996+
[1.0, 2.0, 3.0, 4.0],
997+
[1.0 + 0.j, 2.0 + 0.j, 3.0 + 0.j, 4.0 + 0.j]
998+
],
999+
)
1000+
def test_fill_value(
1001+
self,
1002+
xp: ModuleType,
1003+
fill_value: float,
1004+
output: float,
1005+
) -> None:
9921006
a = xp.asarray([1, 2, np.nan, 4])
9931007
xp_assert_equal(
994-
nan_to_num(a, fill_value=3, xp=xp),
995-
xp.asarray([1.0, 2.0, 3.0, 4.0]),
1008+
nan_to_num(a, fill_value=fill_value, xp=xp),
1009+
xp.asarray(output),
9961010
)
9971011

9981012
def test_empty_array(self, xp: ModuleType) -> None:

0 commit comments

Comments
 (0)