Skip to content

Commit d2e491c

Browse files
committed
MAINT: signal: convert lp2{lp,hp, bp, bs} to be array api compatible
1 parent 89f8532 commit d2e491c

File tree

2 files changed

+145
-65
lines changed

2 files changed

+145
-65
lines changed

scipy/signal/_filter_design.py

Lines changed: 111 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
from numpy import (atleast_1d, poly, polyval, roots, real, asarray,
8-
resize, pi, absolute, sqrt, tan, log10,
8+
pi, absolute, sqrt, tan, log10,
99
arcsinh, sin, exp, cosh, arccosh, ceil, conjugate,
1010
zeros, sinh, append, concatenate, prod, ones, full, array,
1111
mintypecode)
@@ -17,6 +17,9 @@
1717
from scipy._lib._util import float_factorial
1818
from scipy.signal._arraytools import _validate_fs
1919

20+
import scipy._lib.array_api_extra as xpx
21+
from scipy._lib._array_api import array_namespace, xp_promote, xp_size
22+
2023

2124
__all__ = ['findfreqs', 'freqs', 'freqz', 'tf2zpk', 'zpk2tf', 'normalize',
2225
'lp2lp', 'lp2hp', 'lp2bp', 'lp2bs', 'bilinear', 'iirdesign',
@@ -1676,7 +1679,7 @@ def idx_worst(p):
16761679
return sos
16771680

16781681

1679-
def _align_nums(nums):
1682+
def _align_nums(nums, xp):
16801683
"""Aligns the shapes of multiple numerators.
16811684
16821685
Given an array of numerator coefficient arrays [[a_1, a_2,...,
@@ -1701,19 +1704,19 @@ def _align_nums(nums):
17011704
# The statement can throw a ValueError if one
17021705
# of the numerators is a single digit and another
17031706
# is array-like e.g. if nums = [5, [1, 2, 3]]
1704-
nums = asarray(nums)
1707+
nums = xp.asarray(nums)
17051708

1706-
if not np.issubdtype(nums.dtype, np.number):
1709+
if not xp.isdtype(nums.dtype, "numeric"):
17071710
raise ValueError("dtype of numerator is non-numeric")
17081711

17091712
return nums
17101713

17111714
except ValueError:
1712-
nums = [np.atleast_1d(num) for num in nums]
1713-
max_width = max(num.size for num in nums)
1715+
nums = [xpx.atleast_nd(xp.asarray(num), ndim=1) for num in nums]
1716+
max_width = max(xp_size(num) for num in nums)
17141717

17151718
# pre-allocate
1716-
aligned_nums = np.zeros((len(nums), max_width))
1719+
aligned_nums = xp.zeros((nums.shape[0], max_width))
17171720

17181721
# Create numerators with padded zeros
17191722
for index, num in enumerate(nums):
@@ -1722,6 +1725,26 @@ def _align_nums(nums):
17221725
return aligned_nums
17231726

17241727

1728+
def _trim_zeros(filt, trim='fb'):
1729+
# https://github.com/numpy/numpy/blob/v2.1.0/numpy/lib/_function_base_impl.py#L1874-L1925
1730+
first = 0
1731+
trim = trim.upper()
1732+
if 'F' in trim:
1733+
for i in filt:
1734+
if i != 0.:
1735+
break
1736+
else:
1737+
first = first + 1
1738+
last = filt.shape[0]
1739+
if 'B' in trim:
1740+
for i in filt[::-1]:
1741+
if i != 0.:
1742+
break
1743+
else:
1744+
last = last - 1
1745+
return filt[first:last]
1746+
1747+
17251748
def normalize(b, a):
17261749
"""Normalize numerator/denominator of a continuous-time transfer function.
17271750
@@ -1778,30 +1801,33 @@ def normalize(b, a):
17781801
Badly conditioned filter coefficients (numerator): the results may be meaningless
17791802
17801803
"""
1781-
num, den = b, a
1804+
xp = array_namespace(b, a)
17821805

1783-
den = np.asarray(den)
1784-
den = np.atleast_1d(den)
1785-
num = np.atleast_2d(_align_nums(num))
1806+
den = xp.asarray(a)
1807+
den = xpx.atleast_nd(den, ndim=1, xp=xp)
1808+
1809+
num = xp.asarray(b)
1810+
num = xpx.atleast_nd(_align_nums(num, xp), ndim=2, xp=xp)
17861811

17871812
if den.ndim != 1:
17881813
raise ValueError("Denominator polynomial must be rank-1 array.")
17891814
if num.ndim > 2:
17901815
raise ValueError("Numerator polynomial must be rank-1 or"
17911816
" rank-2 array.")
1792-
if np.all(den == 0):
1817+
if xp.all(den == 0):
17931818
raise ValueError("Denominator must have at least on nonzero element.")
17941819

17951820
# Trim leading zeros in denominator, leave at least one.
1796-
den = np.trim_zeros(den, 'f')
1821+
den = _trim_zeros(den, 'f')
17971822

17981823
# Normalize transfer function
17991824
num, den = num / den[0], den / den[0]
18001825

18011826
# Count numerator columns that are all zero
18021827
leading_zeros = 0
1803-
for col in num.T:
1804-
if np.allclose(col, 0, atol=1e-14):
1828+
for j in range(num.shape[-1]):
1829+
col = num[:, j]
1830+
if xp.all(xp.abs(col) <= 1e-14):
18051831
leading_zeros += 1
18061832
else:
18071833
break
@@ -1879,22 +1905,49 @@ def lp2lp(b, a, wo=1.0):
18791905
>>> plt.legend()
18801906
18811907
"""
1882-
a, b = map(atleast_1d, (a, b))
1908+
xp = array_namespace(a, b)
1909+
a, b = map(xp.asarray, (a, b))
1910+
a, b = xp_promote(a, b, force_floating=True, xp=xp)
1911+
a = xpx.atleast_nd(a, ndim=1, xp=xp)
1912+
b = xpx.atleast_nd(b, ndim=1, xp=xp)
1913+
18831914
try:
18841915
wo = float(wo)
18851916
except TypeError:
18861917
wo = float(wo[0])
1887-
d = len(a)
1888-
n = len(b)
1918+
d = a.shape[0]
1919+
n = b.shape[0]
18891920
M = max((d, n))
1890-
pwo = pow(wo, np.arange(M - 1, -1, -1))
1921+
pwo = wo ** xp.arange(M - 1, -1, -1, dtype=xp.float64)
18911922
start1 = max((n - d, 0))
18921923
start2 = max((d - n, 0))
18931924
b = b * pwo[start1] / pwo[start2:]
18941925
a = a * pwo[start1] / pwo[start1:]
18951926
return normalize(b, a)
18961927

18971928

1929+
def _resize(a, new_shape, xp):
1930+
# https://github.com/numpy/numpy/blob/v2.2.4/numpy/_core/fromnumeric.py#L1535
1931+
a = xp.reshape(a, (-1,))
1932+
1933+
new_size = 1
1934+
for dim_length in new_shape:
1935+
new_size *= dim_length
1936+
if dim_length < 0:
1937+
raise ValueError(
1938+
'all elements of `new_shape` must be non-negative'
1939+
)
1940+
1941+
if xp_size(a) == 0 or new_size == 0:
1942+
# First case must zero fill. The second would have repeats == 0.
1943+
return xp.zeros_like(a, shape=new_shape)
1944+
1945+
repeats = -(-new_size // xp_size(a)) # ceil division
1946+
a = xp.concat((a,) * repeats)[:new_size]
1947+
1948+
return xp.reshape(a, new_shape)
1949+
1950+
18981951
def lp2hp(b, a, wo=1.0):
18991952
r"""
19001953
Transform a lowpass filter prototype to a highpass filter.
@@ -1953,27 +2006,34 @@ def lp2hp(b, a, wo=1.0):
19532006
>>> plt.legend()
19542007
19552008
"""
1956-
a, b = map(atleast_1d, (a, b))
2009+
xp = array_namespace(a, b)
2010+
2011+
a, b = map(xp.asarray, (a, b))
2012+
a, b = xp_promote(a, b, force_floating=True, xp=xp)
2013+
a = xpx.atleast_nd(a, ndim=1, xp=xp)
2014+
b = xpx.atleast_nd(b, ndim=1, xp=xp)
2015+
19572016
try:
19582017
wo = float(wo)
19592018
except TypeError:
19602019
wo = float(wo[0])
1961-
d = len(a)
1962-
n = len(b)
2020+
d = a.shape[0]
2021+
n = b.shape[0]
19632022
if wo != 1:
1964-
pwo = pow(wo, np.arange(max((d, n))))
2023+
pwo = wo ** xp.arange(max((d, n)), dtype=xp.float64)
19652024
else:
1966-
pwo = np.ones(max((d, n)), b.dtype.char)
2025+
pwo = xp.ones(max((d, n)), dtype=b.dtype)
19672026
if d >= n:
1968-
outa = a[::-1] * pwo
1969-
outb = resize(b, (d,))
2027+
outa = xp.flip(a) * pwo
2028+
outb = xp.concat((xp.zeros(n, dtype=b.dtype), ))
2029+
outb = _resize(b, (d,), xp=xp)
19702030
outb[n:] = 0.0
1971-
outb[:n] = b[::-1] * pwo[:n]
2031+
outb[:n] = xp.flip(b) * pwo[:n]
19722032
else:
1973-
outb = b[::-1] * pwo
1974-
outa = resize(a, (n,))
2033+
outb = xp.flip(b) * pwo
2034+
outa = _resize(a, (n,), xp=xp)
19752035
outa[d:] = 0.0
1976-
outa[:d] = a[::-1] * pwo[:d]
2036+
outa[:d] = xp.flip(a) * pwo[:d]
19772037

19782038
return normalize(outb, outa)
19792039

@@ -2038,16 +2098,20 @@ def lp2bp(b, a, wo=1.0, bw=1.0):
20382098
>>> plt.ylabel('Amplitude [dB]')
20392099
>>> plt.legend()
20402100
"""
2101+
xp = array_namespace(a, b)
2102+
2103+
a, b = map(xp.asarray, (a, b))
2104+
a, b = xp_promote(a, b, force_floating=True, xp=xp)
2105+
a = xpx.atleast_nd(a, ndim=1, xp=xp)
2106+
b = xpx.atleast_nd(b, ndim=1, xp=xp)
20412107

2042-
a, b = map(atleast_1d, (a, b))
2043-
D = len(a) - 1
2044-
N = len(b) - 1
2045-
artype = mintypecode((a, b))
2108+
D = a.shape[0] - 1
2109+
N = b.shape[0] - 1
20462110
ma = max([N, D])
20472111
Np = N + ma
20482112
Dp = D + ma
2049-
bprime = np.empty(Np + 1, artype)
2050-
aprime = np.empty(Dp + 1, artype)
2113+
bprime = xp.empty(Np + 1, dtype=b.dtype)
2114+
aprime = xp.empty(Dp + 1, dtype=a.dtype)
20512115
wosq = wo * wo
20522116
for j in range(Np + 1):
20532117
val = 0.0
@@ -2126,15 +2190,20 @@ def lp2bs(b, a, wo=1.0, bw=1.0):
21262190
>>> plt.ylabel('Amplitude [dB]')
21272191
>>> plt.legend()
21282192
"""
2129-
a, b = map(atleast_1d, (a, b))
2130-
D = len(a) - 1
2131-
N = len(b) - 1
2132-
artype = mintypecode((a, b))
2193+
xp = array_namespace(a, b)
2194+
2195+
a, b = map(xp.asarray, (a, b))
2196+
a, b = xp_promote(a, b, force_floating=True, xp=xp)
2197+
a = xpx.atleast_nd(a, ndim=1, xp=xp)
2198+
b = xpx.atleast_nd(b, ndim=1, xp=xp)
2199+
2200+
D = a.shape[0] - 1
2201+
N = b.shape[0] - 1
21332202
M = max([N, D])
21342203
Np = M + M
21352204
Dp = M + M
2136-
bprime = np.empty(Np + 1, artype)
2137-
aprime = np.empty(Dp + 1, artype)
2205+
bprime = xp.empty(Np + 1, dtype=b.dtype)
2206+
aprime = xp.empty(Dp + 1, dtype=a.dtype)
21382207
wosq = wo * wo
21392208
for j in range(Np + 1):
21402209
val = 0.0

scipy/signal/tests/test_filter_design.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import warnings
23

34
from itertools import product
@@ -27,6 +28,9 @@
2728
from scipy.signal._filter_design import (_cplxreal, _cplxpair, _norm_factor,
2829
_bessel_poly, _bessel_zeros)
2930

31+
skip_xp_backends = pytest.mark.skip_xp_backends
32+
xfail_xp_backends = pytest.mark.xfail_xp_backends
33+
3034

3135
try:
3236
import mpmath
@@ -1351,43 +1355,50 @@ def test_errors(self):
13511355

13521356
class TestLp2lp:
13531357

1354-
def test_basic(self):
1355-
b = [1]
1356-
a = [1, np.sqrt(2), 1]
1358+
def test_basic(self, xp):
1359+
b = xp.asarray([1])
1360+
a = xp.asarray([1, math.sqrt(2), 1])
13571361
b_lp, a_lp = lp2lp(b, a, 0.38574256627112119)
1358-
assert_array_almost_equal(b_lp, [0.1488], decimal=4)
1359-
assert_array_almost_equal(a_lp, [1, 0.5455, 0.1488], decimal=4)
1362+
assert_array_almost_equal(b_lp, xp.asarray([0.1488]), decimal=4)
1363+
assert_array_almost_equal(a_lp, xp.asarray([1, 0.5455, 0.1488]), decimal=4)
13601364

13611365

13621366
class TestLp2hp:
13631367

1364-
def test_basic(self):
1365-
b = [0.25059432325190018]
1366-
a = [1, 0.59724041654134863, 0.92834805757524175, 0.25059432325190018]
1367-
b_hp, a_hp = lp2hp(b, a, 2*np.pi*5000)
1368-
xp_assert_close(b_hp, [1.0, 0, 0, 0])
1369-
xp_assert_close(a_hp, [1, 1.1638e5, 2.3522e9, 1.2373e14], rtol=1e-4)
1368+
def test_basic(self, xp):
1369+
b = xp.asarray([0.25059432325190018])
1370+
a = xp.asarray(
1371+
[1, 0.59724041654134863, 0.92834805757524175, 0.25059432325190018]
1372+
)
1373+
b_hp, a_hp = lp2hp(b, a, 2*math.pi*5000)
1374+
xp_assert_close(b_hp, xp.asarray([1.0, 0, 0, 0]))
1375+
xp_assert_close(
1376+
a_hp, xp.asarray([1, 1.1638e5, 2.3522e9, 1.2373e14]), rtol=1e-4
1377+
)
13701378

13711379

13721380
class TestLp2bp:
13731381

1374-
def test_basic(self):
1375-
b = [1]
1376-
a = [1, 2, 2, 1]
1377-
b_bp, a_bp = lp2bp(b, a, 2*np.pi*4000, 2*np.pi*2000)
1378-
xp_assert_close(b_bp, [1.9844e12, 0, 0, 0], rtol=1e-6)
1379-
xp_assert_close(a_bp, [1, 2.5133e4, 2.2108e9, 3.3735e13,
1380-
1.3965e18, 1.0028e22, 2.5202e26], rtol=1e-4)
1382+
def test_basic(self, xp):
1383+
b = xp.asarray([1])
1384+
a = xp.asarray([1, 2, 2, 1])
1385+
b_bp, a_bp = lp2bp(b, a, 2*math.pi*4000, 2*math.pi*2000)
1386+
xp_assert_close(b_bp, xp.asarray([1.9844e12, 0, 0, 0]), rtol=1e-6)
1387+
xp_assert_close(
1388+
a_bp,
1389+
xp.asarray([1, 2.5133e4, 2.2108e9, 3.3735e13,
1390+
1.3965e18, 1.0028e22, 2.5202e26]), rtol=1e-4
1391+
)
13811392

13821393

13831394
class TestLp2bs:
13841395

1385-
def test_basic(self):
1386-
b = [1]
1387-
a = [1, 1]
1396+
def test_basic(self, xp):
1397+
b = xp.asarray([1])
1398+
a = xp.asarray([1, 1])
13881399
b_bs, a_bs = lp2bs(b, a, 0.41722257286366754, 0.18460575326152251)
1389-
assert_array_almost_equal(b_bs, [1, 0, 0.17407], decimal=5)
1390-
assert_array_almost_equal(a_bs, [1, 0.18461, 0.17407], decimal=5)
1400+
assert_array_almost_equal(b_bs, xp.asarray([1, 0, 0.17407]), decimal=5)
1401+
assert_array_almost_equal(a_bs, xp.asarray([1, 0.18461, 0.17407]), decimal=5)
13911402

13921403

13931404
class TestBilinear:

0 commit comments

Comments
 (0)