Skip to content

Commit 5074a28

Browse files
committed
MAINT: signal: bilinear_zpk array API
1 parent e12326b commit 5074a28

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

scipy/signal/_filter_design.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,8 +2863,11 @@ def bilinear_zpk(z, p, k, fs):
28632863
>>> plt.ylabel('Amplitude [dB]')
28642864
>>> plt.grid(True)
28652865
"""
2866-
z = atleast_1d(z)
2867-
p = atleast_1d(p)
2866+
xp = array_namespace(z, p)
2867+
2868+
z, p = map(xp.asarray, (z, p))
2869+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
2870+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
28682871

28692872
fs = _validate_fs(fs, allow_none=False)
28702873

@@ -2877,10 +2880,10 @@ def bilinear_zpk(z, p, k, fs):
28772880
p_z = (fs2 + p) / (fs2 - p)
28782881

28792882
# Any zeros that were at infinity get moved to the Nyquist frequency
2880-
z_z = append(z_z, -ones(degree))
2883+
z_z = xp.concat((z_z, -xp.ones(degree)))
28812884

28822885
# Compensate for gain change
2883-
k_z = k * real(prod(fs2 - z) / prod(fs2 - p))
2886+
k_z = k * xp.real(xp.prod(fs2 - z) / xp.prod(fs2 - p))
28842887

28852888
return z_z, p_z, k_z
28862889

scipy/signal/tests/test_filter_design.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,19 +1614,25 @@ def test_basic(self, xp):
16141614

16151615
class TestBilinear_zpk:
16161616

1617-
def test_basic(self):
1618-
z = [-2j, +2j]
1619-
p = [-0.75, -0.5-0.5j, -0.5+0.5j]
1617+
def test_basic(self, xp):
1618+
z = xp.asarray([-2j, +2j])
1619+
p = xp.asarray([-0.75, -0.5-0.5j, -0.5+0.5j])
16201620
k = 3
16211621

16221622
z_d, p_d, k_d = bilinear_zpk(z, p, k, 10)
16231623

1624-
xp_assert_close(sort(z_d), sort([(20-2j)/(20+2j), (20+2j)/(20-2j),
1625-
-1]))
1626-
xp_assert_close(sort(p_d), sort([77/83,
1627-
(1j/2 + 39/2) / (41/2 - 1j/2),
1628-
(39/2 - 1j/2) / (1j/2 + 41/2), ]))
1629-
xp_assert_close(k_d, 9696/69803)
1624+
xp_assert_close(
1625+
_sort_cmplx(z_d, xp=xp),
1626+
_sort_cmplx([(20-2j) / (20+2j), (20+2j) / (20-2j), -1], xp=xp)
1627+
)
1628+
xp_assert_close(
1629+
_sort_cmplx(p_d, xp=xp),
1630+
_sort_cmplx(
1631+
[77/83, (1j/2 + 39/2) / (41/2 - 1j/2), (39/2 - 1j/2) / (1j/2 + 41/2)],
1632+
xp=xp
1633+
)
1634+
)
1635+
assert math.isclose(k_d, 9696/69803)
16301636

16311637

16321638
class TestPrototypeType:

0 commit comments

Comments
 (0)