Skip to content

Commit e29b9a4

Browse files
authored
Merge pull request scipy#21713 from ev-br/sigtools_convolve_cupy_lfilt
* ENH: signal: array API support / delegation in lfilter-like functions Note: Some CuPy functions deliberately deviate from scipy: {lfilter,sosfilt}_zi return arrays of different shapes. Skip dispatching for these cases; a user will have to be explicit on whether they want the 'original' SciPy output (then call lfilter_zi with NumPy array arguments) or the CuPy output (then call cupyx.scipy.signal.lfilter_zi with CuPy array arguments). * BUG: signal: restore (and test for) array-like list inputs The replacement for `np.atleast_1d(x)` is `xpx.atleast_nd(xp.asarray(x), ndim=1, xp=xp)`, no less. * MAINT: signal: link to a cupy fix * address review comments
2 parents 7f03fba + 40a626e commit e29b9a4

File tree

3 files changed

+673
-418
lines changed

3 files changed

+673
-418
lines changed

scipy/signal/_signaltools.py

Lines changed: 120 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828

2929
from scipy._lib._array_api import (
3030
array_namespace, is_torch, is_numpy, xp_copy, xp_size
31+
3132
)
3233
import scipy._lib.array_api_compat.numpy as np_compat
3334
import scipy._lib.array_api_extra as xpx
3435

36+
3537
__all__ = ['correlate', 'correlation_lags', 'correlate2d',
3638
'convolve', 'convolve2d', 'fftconvolve', 'oaconvolve',
3739
'order_filter', 'medfilt', 'medfilt2d', 'wiener', 'lfilter',
@@ -2192,12 +2194,22 @@ def lfilter(b, a, x, axis=-1, zi=None):
21922194
>>> plt.show()
21932195
21942196
"""
2197+
try:
2198+
xp = array_namespace(b, a, x, zi)
2199+
except TypeError:
2200+
# either in1 or in2 are object arrays
2201+
xp = np_compat
2202+
2203+
if is_numpy(xp):
2204+
_reject_objects(x, 'lfilter')
2205+
_reject_objects(a, 'lfilter')
2206+
_reject_objects(b, 'lfilter')
2207+
21952208
b = np.atleast_1d(b)
21962209
a = np.atleast_1d(a)
2197-
2198-
_reject_objects(x, 'lfilter')
2199-
_reject_objects(a, 'lfilter')
2200-
_reject_objects(b, 'lfilter')
2210+
x = np.asarray(x)
2211+
if zi is not None:
2212+
zi = np.asarray(zi)
22012213

22022214
if len(a) == 1:
22032215
# This path only supports types fdgFDGO to mirror _linear_filter below.
@@ -2256,16 +2268,18 @@ def lfilter(b, a, x, axis=-1, zi=None):
22562268
out = out_full[tuple(ind)]
22572269

22582270
if zi is None:
2259-
return out
2271+
return xp.asarray(out)
22602272
else:
22612273
ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None)
22622274
zf = out_full[tuple(ind)]
2263-
return out, zf
2275+
return xp.asarray(out), xp.asarray(zf)
22642276
else:
22652277
if zi is None:
2266-
return _sigtools._linear_filter(b, a, x, axis)
2278+
result =_sigtools._linear_filter(b, a, x, axis)
2279+
return xp.asarray(result)
22672280
else:
2268-
return _sigtools._linear_filter(b, a, x, axis, zi)
2281+
out, zf = _sigtools._linear_filter(b, a, x, axis, zi)
2282+
return xp.asarray(out), xp.asarray(zf)
22692283

22702284

22712285
def lfiltic(b, a, y, x=None):
@@ -2308,40 +2322,59 @@ def lfiltic(b, a, y, x=None):
23082322
lfilter, lfilter_zi
23092323
23102324
"""
2311-
N = np.size(a) - 1
2312-
M = np.size(b) - 1
2325+
try:
2326+
xp = array_namespace(a, b, y, x)
2327+
except TypeError:
2328+
xp = np_compat
2329+
2330+
if is_numpy(xp):
2331+
_reject_objects(a, 'lfiltic')
2332+
_reject_objects(b, 'lfiltic')
2333+
_reject_objects(y, 'lfiltic')
2334+
if x is not None:
2335+
_reject_objects(x, 'lfiltic')
2336+
2337+
a = xp.asarray(a)
2338+
b = xp.asarray(b)
2339+
2340+
N = xp_size(a) - 1
2341+
M = xp_size(b) - 1
23132342
K = max(M, N)
2314-
y = np.asarray(y)
2343+
y = xp.asarray(y)
23152344

23162345
if x is None:
2317-
result_type = np.result_type(np.asarray(b), np.asarray(a), y)
2318-
if result_type.kind in 'bui':
2319-
result_type = np.float64
2320-
x = np.zeros(M, dtype=result_type)
2346+
result_type = xp.result_type(b, a, y)
2347+
if xp.isdtype(result_type, ('bool', 'integral')): #'bui':
2348+
result_type = xp.float64
2349+
x = xp.zeros(M, dtype=result_type)
23212350
else:
2322-
x = np.asarray(x)
2351+
x = xp.asarray(x)
23232352

2324-
result_type = np.result_type(np.asarray(b), np.asarray(a), y, x)
2325-
if result_type.kind in 'bui':
2326-
result_type = np.float64
2327-
x = x.astype(result_type)
2353+
result_type = xp.result_type(b, a, y, x)
2354+
if xp.isdtype(result_type, ('bool', 'integral')): #'bui':
2355+
result_type = xp.float64
2356+
x = xp.astype(x, result_type)
23282357

2329-
L = np.size(x)
2358+
concat = array_namespace(a).concat
2359+
2360+
L = xp_size(x)
23302361
if L < M:
2331-
x = np.r_[x, np.zeros(M - L)]
2362+
x = concat((x, xp.zeros(M - L)))
2363+
2364+
y = xp.astype(y, result_type)
2365+
zi = xp.zeros(K, dtype=result_type)
23322366

2333-
y = y.astype(result_type)
2334-
zi = np.zeros(K, result_type)
2367+
concat = array_namespace(xp.ones(3)).concat
23352368

2336-
L = np.size(y)
2369+
L = xp_size(y)
23372370
if L < N:
2338-
y = np.r_[y, np.zeros(N - L)]
2371+
y = concat((y, np.zeros(N - L)))
23392372

23402373
for m in range(M):
2341-
zi[m] = np.sum(b[m + 1:] * x[:M - m], axis=0)
2374+
zi[m] = xp.sum(b[m + 1:] * x[:M - m], axis=0)
23422375

23432376
for m in range(N):
2344-
zi[m] -= np.sum(a[m + 1:] * y[:N - m], axis=0)
2377+
zi[m] -= xp.sum(a[m + 1:] * y[:N - m], axis=0)
23452378

23462379
return zi
23472380

@@ -2387,19 +2420,21 @@ def deconvolve(signal, divisor):
23872420
array([ 0., 1., 0., 0., 1., 1., 0., 0.])
23882421
23892422
"""
2390-
num = np.atleast_1d(signal)
2391-
den = np.atleast_1d(divisor)
2423+
xp = array_namespace(signal, divisor)
2424+
2425+
num = xpx.atleast_nd(xp.asarray(signal), ndim=1, xp=xp)
2426+
den = xpx.atleast_nd(xp.asarray(divisor), ndim=1, xp=xp)
23922427
if num.ndim > 1:
23932428
raise ValueError("signal must be 1-D.")
23942429
if den.ndim > 1:
23952430
raise ValueError("divisor must be 1-D.")
2396-
N = len(num)
2397-
D = len(den)
2431+
N = num.shape[0]
2432+
D = den.shape[0]
23982433
if D > N:
23992434
quot = []
24002435
rem = num
24012436
else:
2402-
input = np.zeros(N - D + 1, float)
2437+
input = xp.zeros(N - D + 1, dtype=xp.float64)
24032438
input[0] = 1
24042439
quot = lfilter(num, den, input)
24052440
rem = num - convolve(den, quot, mode='full')
@@ -2550,8 +2585,7 @@ def hilbert2(x, N=None):
25502585
25512586
"""
25522587
xp = array_namespace(x)
2553-
2554-
x = xpx.atleast_nd(x, ndim=2, xp=xp)
2588+
x = xpx.atleast_nd(xp.asarray(x), ndim=2, xp=xp)
25552589
if x.ndim > 2:
25562590
raise ValueError("x must be 2-D.")
25572591
if xp.isdtype(x.dtype, 'complex floating'):
@@ -4147,6 +4181,7 @@ def lfilter_zi(b, a):
41474181
transient until the input drops from 0.5 to 0.0.
41484182
41494183
"""
4184+
xp = array_namespace(b, a)
41504185

41514186
# FIXME: Can this function be replaced with an appropriate
41524187
# use of lfiltic? For example, when b,a = butter(N,Wn),
@@ -4156,35 +4191,37 @@ def lfilter_zi(b, a):
41564191
# We could use scipy.signal.normalize, but it uses warnings in
41574192
# cases where a ValueError is more appropriate, and it allows
41584193
# b to be 2D.
4159-
b = np.atleast_1d(b)
4194+
b = xpx.atleast_nd(xp.asarray(b), ndim=1, xp=xp)
41604195
if b.ndim != 1:
41614196
raise ValueError("Numerator b must be 1-D.")
4162-
a = np.atleast_1d(a)
4197+
a = xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp)
41634198
if a.ndim != 1:
41644199
raise ValueError("Denominator a must be 1-D.")
41654200

4166-
while len(a) > 1 and a[0] == 0.0:
4201+
while a.shape[0] > 1 and a[0] == 0.0:
41674202
a = a[1:]
4168-
if a.size < 1:
4203+
if xp_size(a) < 1:
41694204
raise ValueError("There must be at least one nonzero `a` coefficient.")
41704205

41714206
if a[0] != 1.0:
41724207
# Normalize the coefficients so a[0] == 1.
41734208
b = b / a[0]
41744209
a = a / a[0]
41754210

4176-
n = max(len(a), len(b))
4211+
n = max(a.shape[0], b.shape[0])
41774212

41784213
# Pad a or b with zeros so they are the same length.
4179-
if len(a) < n:
4180-
a = np.r_[a, np.zeros(n - len(a), dtype=a.dtype)]
4181-
elif len(b) < n:
4182-
b = np.r_[b, np.zeros(n - len(b), dtype=b.dtype)]
4183-
4184-
IminusA = np.eye(n - 1, dtype=np.result_type(a, b)) - linalg.companion(a).T
4214+
if a.shape[0] < n:
4215+
a = xp.concat((a, xp.zeros(n - a.shape[0], dtype=a.dtype)))
4216+
elif b.shape[0] < n:
4217+
b = xp.concat((b, xp.zeros(n - b.shape[0], dtype=b.dtype)))
4218+
4219+
dt = xp.result_type(a, b)
4220+
IminusA = np.eye(n - 1) - linalg.companion(a).T
4221+
IminusA = xp.asarray(IminusA, dtype=dt)
41854222
B = b[1:] - a[1:] * b[0]
41864223
# Solve zi = A*zi + B
4187-
zi = np.linalg.solve(IminusA, B)
4224+
zi = xp.linalg.solve(IminusA, B)
41884225

41894226
# For future reference: we could also use the following
41904227
# explicit formulas to solve the linear system:
@@ -4255,24 +4292,26 @@ def sosfilt_zi(sos):
42554292
>>> plt.show()
42564293
42574294
"""
4258-
sos = np.asarray(sos)
4295+
xp = array_namespace(sos)
4296+
4297+
sos = xp.asarray(sos)
42594298
if sos.ndim != 2 or sos.shape[1] != 6:
42604299
raise ValueError('sos must be shape (n_sections, 6)')
42614300

4262-
if sos.dtype.kind in 'bui':
4263-
sos = sos.astype(np.float64)
4301+
if xp.isdtype(sos.dtype, ("integral", "bool")):
4302+
sos = xp.astype(sos, xp.float64)
42644303

42654304
n_sections = sos.shape[0]
4266-
zi = np.empty((n_sections, 2), dtype=sos.dtype)
4305+
zi = xp.empty((n_sections, 2), dtype=sos.dtype)
42674306
scale = 1.0
42684307
for section in range(n_sections):
42694308
b = sos[section, :3]
42704309
a = sos[section, 3:]
4271-
zi[section] = scale * lfilter_zi(b, a)
4310+
zi[section, ...] = scale * lfilter_zi(b, a)
42724311
# If H(z) = B(z)/A(z) is this section's transfer function, then
42734312
# b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
42744313
# state value of this section's step response.
4275-
scale *= b.sum() / a.sum()
4314+
scale *= xp.sum(b) / xp.sum(a)
42764315

42774316
return zi
42784317

@@ -4614,6 +4653,8 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
46144653
2.875334415008979e-10
46154654
46164655
"""
4656+
xp = array_namespace(b, a, x)
4657+
46174658
b = np.atleast_1d(b)
46184659
a = np.atleast_1d(a)
46194660
x = np.asarray(x)
@@ -4623,7 +4664,7 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
46234664

46244665
if method == "gust":
46254666
y, z1, z2 = _filtfilt_gust(b, a, x, axis=axis, irlen=irlen)
4626-
return y
4667+
return xp.asarray(y)
46274668

46284669
# method == "pad"
46294670
edge, ext = _validate_pad(padtype, padlen, x, axis,
@@ -4655,7 +4696,7 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
46554696
# Slice the actual signal from the extended signal.
46564697
y = axis_slice(y, start=edge, stop=-edge, axis=axis)
46574698

4658-
return y
4699+
return xp.asarray(y)
46594700

46604701

46614702
def _validate_pad(padtype, padlen, x, axis, ntaps):
@@ -4769,10 +4810,17 @@ def sosfilt(sos, x, axis=-1, zi=None):
47694810
>>> plt.show()
47704811
47714812
"""
4772-
_reject_objects(sos, 'sosfilt')
4773-
_reject_objects(x, 'sosfilt')
4774-
if zi is not None:
4775-
_reject_objects(zi, 'sosfilt')
4813+
try:
4814+
xp = array_namespace(sos, x, zi)
4815+
except TypeError:
4816+
# either in1 or in2 are object arrays
4817+
xp = np_compat
4818+
4819+
if is_numpy(xp):
4820+
_reject_objects(sos, 'sosfilt')
4821+
_reject_objects(x, 'sosfilt')
4822+
if zi is not None:
4823+
_reject_objects(zi, 'sosfilt')
47764824

47774825
x = _validate_x(x)
47784826
sos, n_sections = _validate_sos(sos)
@@ -4786,7 +4834,12 @@ def sosfilt(sos, x, axis=-1, zi=None):
47864834
if dtype.char not in 'fdgFDGO':
47874835
raise NotImplementedError(f"input type '{dtype}' not supported")
47884836
if zi is not None:
4789-
zi = np.array(zi, dtype) # make a copy so that we can operate in place
4837+
zi = np.asarray(zi, dtype=dtype)
4838+
4839+
# make a copy so that we can operate in place
4840+
# NB: 1. use xp_copy to paper over numpy 1/2 copy= keyword
4841+
# 2. make sure the copied zi remains a numpy array
4842+
zi = xp_copy(zi, xp=array_namespace(zi))
47904843
if zi.shape != x_zi_shape:
47914844
raise ValueError('Invalid zi shape. With axis=%r, an input with '
47924845
'shape %r, and an sos array with %d sections, zi '
@@ -4798,7 +4851,7 @@ def sosfilt(sos, x, axis=-1, zi=None):
47984851
return_zi = False
47994852
axis = axis % x.ndim # make positive
48004853
x = np.moveaxis(x, axis, -1)
4801-
zi = np.moveaxis(zi, [0, axis + 1], [-2, -1])
4854+
zi = np.moveaxis(zi, (0, axis + 1), (-2, -1))
48024855
x_shape, zi_shape = x.shape, zi.shape
48034856
x = np.reshape(x, (-1, x.shape[-1]))
48044857
x = np.array(x, dtype, order='C') # make a copy, can modify in place
@@ -4809,10 +4862,10 @@ def sosfilt(sos, x, axis=-1, zi=None):
48094862
x = np.moveaxis(x, -1, axis)
48104863
if return_zi:
48114864
zi.shape = zi_shape
4812-
zi = np.moveaxis(zi, [-2, -1], [0, axis + 1])
4813-
out = (x, zi)
4865+
zi = np.moveaxis(zi, (-2, -1), (0, axis + 1))
4866+
out = (xp.asarray(x), xp.asarray(zi))
48144867
else:
4815-
out = x
4868+
out = xp.asarray(x)
48164869
return out
48174870

48184871

@@ -4905,6 +4958,8 @@ def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
49054958
>>> plt.show()
49064959
49074960
"""
4961+
xp = array_namespace(sos, x)
4962+
49084963
sos, n_sections = _validate_sos(sos)
49094964
x = _validate_x(x)
49104965

@@ -4926,7 +4981,7 @@ def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
49264981
y = axis_reverse(y, axis=axis)
49274982
if edge > 0:
49284983
y = axis_slice(y, start=edge, stop=-edge, axis=axis)
4929-
return y
4984+
return xp.asarray(y)
49304985

49314986

49324987
def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True):

0 commit comments

Comments
 (0)