28
28
29
29
from scipy ._lib ._array_api import (
30
30
array_namespace , is_torch , is_numpy , xp_copy , xp_size
31
+
31
32
)
32
33
import scipy ._lib .array_api_compat .numpy as np_compat
33
34
import scipy ._lib .array_api_extra as xpx
34
35
36
+
35
37
__all__ = ['correlate' , 'correlation_lags' , 'correlate2d' ,
36
38
'convolve' , 'convolve2d' , 'fftconvolve' , 'oaconvolve' ,
37
39
'order_filter' , 'medfilt' , 'medfilt2d' , 'wiener' , 'lfilter' ,
@@ -2192,12 +2194,22 @@ def lfilter(b, a, x, axis=-1, zi=None):
2192
2194
>>> plt.show()
2193
2195
2194
2196
"""
2197
+ try :
2198
+ xp = array_namespace (b , a , x , zi )
2199
+ except TypeError :
2200
+ # either in1 or in2 are object arrays
2201
+ xp = np_compat
2202
+
2203
+ if is_numpy (xp ):
2204
+ _reject_objects (x , 'lfilter' )
2205
+ _reject_objects (a , 'lfilter' )
2206
+ _reject_objects (b , 'lfilter' )
2207
+
2195
2208
b = np .atleast_1d (b )
2196
2209
a = np .atleast_1d (a )
2197
-
2198
- _reject_objects (x , 'lfilter' )
2199
- _reject_objects (a , 'lfilter' )
2200
- _reject_objects (b , 'lfilter' )
2210
+ x = np .asarray (x )
2211
+ if zi is not None :
2212
+ zi = np .asarray (zi )
2201
2213
2202
2214
if len (a ) == 1 :
2203
2215
# This path only supports types fdgFDGO to mirror _linear_filter below.
@@ -2256,16 +2268,18 @@ def lfilter(b, a, x, axis=-1, zi=None):
2256
2268
out = out_full [tuple (ind )]
2257
2269
2258
2270
if zi is None :
2259
- return out
2271
+ return xp . asarray ( out )
2260
2272
else :
2261
2273
ind [axis ] = slice (out_full .shape [axis ] - len (b ) + 1 , None )
2262
2274
zf = out_full [tuple (ind )]
2263
- return out , zf
2275
+ return xp . asarray ( out ), xp . asarray ( zf )
2264
2276
else :
2265
2277
if zi is None :
2266
- return _sigtools ._linear_filter (b , a , x , axis )
2278
+ result = _sigtools ._linear_filter (b , a , x , axis )
2279
+ return xp .asarray (result )
2267
2280
else :
2268
- return _sigtools ._linear_filter (b , a , x , axis , zi )
2281
+ out , zf = _sigtools ._linear_filter (b , a , x , axis , zi )
2282
+ return xp .asarray (out ), xp .asarray (zf )
2269
2283
2270
2284
2271
2285
def lfiltic (b , a , y , x = None ):
@@ -2308,40 +2322,59 @@ def lfiltic(b, a, y, x=None):
2308
2322
lfilter, lfilter_zi
2309
2323
2310
2324
"""
2311
- N = np .size (a ) - 1
2312
- M = np .size (b ) - 1
2325
+ try :
2326
+ xp = array_namespace (a , b , y , x )
2327
+ except TypeError :
2328
+ xp = np_compat
2329
+
2330
+ if is_numpy (xp ):
2331
+ _reject_objects (a , 'lfiltic' )
2332
+ _reject_objects (b , 'lfiltic' )
2333
+ _reject_objects (y , 'lfiltic' )
2334
+ if x is not None :
2335
+ _reject_objects (x , 'lfiltic' )
2336
+
2337
+ a = xp .asarray (a )
2338
+ b = xp .asarray (b )
2339
+
2340
+ N = xp_size (a ) - 1
2341
+ M = xp_size (b ) - 1
2313
2342
K = max (M , N )
2314
- y = np .asarray (y )
2343
+ y = xp .asarray (y )
2315
2344
2316
2345
if x is None :
2317
- result_type = np .result_type (np . asarray ( b ), np . asarray ( a ) , y )
2318
- if result_type . kind in 'bui' :
2319
- result_type = np .float64
2320
- x = np .zeros (M , dtype = result_type )
2346
+ result_type = xp .result_type (b , a , y )
2347
+ if xp . isdtype ( result_type , ( 'bool' , 'integral' )): # 'bui':
2348
+ result_type = xp .float64
2349
+ x = xp .zeros (M , dtype = result_type )
2321
2350
else :
2322
- x = np .asarray (x )
2351
+ x = xp .asarray (x )
2323
2352
2324
- result_type = np .result_type (np . asarray ( b ), np . asarray ( a ) , y , x )
2325
- if result_type . kind in 'bui' :
2326
- result_type = np .float64
2327
- x = x .astype (result_type )
2353
+ result_type = xp .result_type (b , a , y , x )
2354
+ if xp . isdtype ( result_type , ( 'bool' , 'integral' )): # 'bui':
2355
+ result_type = xp .float64
2356
+ x = xp .astype (x , result_type )
2328
2357
2329
- L = np .size (x )
2358
+ concat = array_namespace (a ).concat
2359
+
2360
+ L = xp_size (x )
2330
2361
if L < M :
2331
- x = np .r_ [x , np .zeros (M - L )]
2362
+ x = concat ((x , xp .zeros (M - L )))
2363
+
2364
+ y = xp .astype (y , result_type )
2365
+ zi = xp .zeros (K , dtype = result_type )
2332
2366
2333
- y = y .astype (result_type )
2334
- zi = np .zeros (K , result_type )
2367
+ concat = array_namespace (xp .ones (3 )).concat
2335
2368
2336
- L = np . size (y )
2369
+ L = xp_size (y )
2337
2370
if L < N :
2338
- y = np . r_ [ y , np .zeros (N - L )]
2371
+ y = concat (( y , np .zeros (N - L )))
2339
2372
2340
2373
for m in range (M ):
2341
- zi [m ] = np .sum (b [m + 1 :] * x [:M - m ], axis = 0 )
2374
+ zi [m ] = xp .sum (b [m + 1 :] * x [:M - m ], axis = 0 )
2342
2375
2343
2376
for m in range (N ):
2344
- zi [m ] -= np .sum (a [m + 1 :] * y [:N - m ], axis = 0 )
2377
+ zi [m ] -= xp .sum (a [m + 1 :] * y [:N - m ], axis = 0 )
2345
2378
2346
2379
return zi
2347
2380
@@ -2387,19 +2420,21 @@ def deconvolve(signal, divisor):
2387
2420
array([ 0., 1., 0., 0., 1., 1., 0., 0.])
2388
2421
2389
2422
"""
2390
- num = np .atleast_1d (signal )
2391
- den = np .atleast_1d (divisor )
2423
+ xp = array_namespace (signal , divisor )
2424
+
2425
+ num = xpx .atleast_nd (xp .asarray (signal ), ndim = 1 , xp = xp )
2426
+ den = xpx .atleast_nd (xp .asarray (divisor ), ndim = 1 , xp = xp )
2392
2427
if num .ndim > 1 :
2393
2428
raise ValueError ("signal must be 1-D." )
2394
2429
if den .ndim > 1 :
2395
2430
raise ValueError ("divisor must be 1-D." )
2396
- N = len ( num )
2397
- D = len ( den )
2431
+ N = num . shape [ 0 ]
2432
+ D = den . shape [ 0 ]
2398
2433
if D > N :
2399
2434
quot = []
2400
2435
rem = num
2401
2436
else :
2402
- input = np .zeros (N - D + 1 , float )
2437
+ input = xp .zeros (N - D + 1 , dtype = xp . float64 )
2403
2438
input [0 ] = 1
2404
2439
quot = lfilter (num , den , input )
2405
2440
rem = num - convolve (den , quot , mode = 'full' )
@@ -2550,8 +2585,7 @@ def hilbert2(x, N=None):
2550
2585
2551
2586
"""
2552
2587
xp = array_namespace (x )
2553
-
2554
- x = xpx .atleast_nd (x , ndim = 2 , xp = xp )
2588
+ x = xpx .atleast_nd (xp .asarray (x ), ndim = 2 , xp = xp )
2555
2589
if x .ndim > 2 :
2556
2590
raise ValueError ("x must be 2-D." )
2557
2591
if xp .isdtype (x .dtype , 'complex floating' ):
@@ -4147,6 +4181,7 @@ def lfilter_zi(b, a):
4147
4181
transient until the input drops from 0.5 to 0.0.
4148
4182
4149
4183
"""
4184
+ xp = array_namespace (b , a )
4150
4185
4151
4186
# FIXME: Can this function be replaced with an appropriate
4152
4187
# use of lfiltic? For example, when b,a = butter(N,Wn),
@@ -4156,35 +4191,37 @@ def lfilter_zi(b, a):
4156
4191
# We could use scipy.signal.normalize, but it uses warnings in
4157
4192
# cases where a ValueError is more appropriate, and it allows
4158
4193
# b to be 2D.
4159
- b = np . atleast_1d ( b )
4194
+ b = xpx . atleast_nd ( xp . asarray ( b ), ndim = 1 , xp = xp )
4160
4195
if b .ndim != 1 :
4161
4196
raise ValueError ("Numerator b must be 1-D." )
4162
- a = np . atleast_1d ( a )
4197
+ a = xpx . atleast_nd ( xp . asarray ( a ), ndim = 1 , xp = xp )
4163
4198
if a .ndim != 1 :
4164
4199
raise ValueError ("Denominator a must be 1-D." )
4165
4200
4166
- while len ( a ) > 1 and a [0 ] == 0.0 :
4201
+ while a . shape [ 0 ] > 1 and a [0 ] == 0.0 :
4167
4202
a = a [1 :]
4168
- if a . size < 1 :
4203
+ if xp_size ( a ) < 1 :
4169
4204
raise ValueError ("There must be at least one nonzero `a` coefficient." )
4170
4205
4171
4206
if a [0 ] != 1.0 :
4172
4207
# Normalize the coefficients so a[0] == 1.
4173
4208
b = b / a [0 ]
4174
4209
a = a / a [0 ]
4175
4210
4176
- n = max (len ( a ), len ( b ) )
4211
+ n = max (a . shape [ 0 ], b . shape [ 0 ] )
4177
4212
4178
4213
# Pad a or b with zeros so they are the same length.
4179
- if len (a ) < n :
4180
- a = np .r_ [a , np .zeros (n - len (a ), dtype = a .dtype )]
4181
- elif len (b ) < n :
4182
- b = np .r_ [b , np .zeros (n - len (b ), dtype = b .dtype )]
4183
-
4184
- IminusA = np .eye (n - 1 , dtype = np .result_type (a , b )) - linalg .companion (a ).T
4214
+ if a .shape [0 ] < n :
4215
+ a = xp .concat ((a , xp .zeros (n - a .shape [0 ], dtype = a .dtype )))
4216
+ elif b .shape [0 ] < n :
4217
+ b = xp .concat ((b , xp .zeros (n - b .shape [0 ], dtype = b .dtype )))
4218
+
4219
+ dt = xp .result_type (a , b )
4220
+ IminusA = np .eye (n - 1 ) - linalg .companion (a ).T
4221
+ IminusA = xp .asarray (IminusA , dtype = dt )
4185
4222
B = b [1 :] - a [1 :] * b [0 ]
4186
4223
# Solve zi = A*zi + B
4187
- zi = np .linalg .solve (IminusA , B )
4224
+ zi = xp .linalg .solve (IminusA , B )
4188
4225
4189
4226
# For future reference: we could also use the following
4190
4227
# explicit formulas to solve the linear system:
@@ -4255,24 +4292,26 @@ def sosfilt_zi(sos):
4255
4292
>>> plt.show()
4256
4293
4257
4294
"""
4258
- sos = np .asarray (sos )
4295
+ xp = array_namespace (sos )
4296
+
4297
+ sos = xp .asarray (sos )
4259
4298
if sos .ndim != 2 or sos .shape [1 ] != 6 :
4260
4299
raise ValueError ('sos must be shape (n_sections, 6)' )
4261
4300
4262
- if sos .dtype . kind in 'bui' :
4263
- sos = sos .astype (np .float64 )
4301
+ if xp . isdtype ( sos .dtype , ( "integral" , "bool" )) :
4302
+ sos = xp .astype (sos , xp .float64 )
4264
4303
4265
4304
n_sections = sos .shape [0 ]
4266
- zi = np .empty ((n_sections , 2 ), dtype = sos .dtype )
4305
+ zi = xp .empty ((n_sections , 2 ), dtype = sos .dtype )
4267
4306
scale = 1.0
4268
4307
for section in range (n_sections ):
4269
4308
b = sos [section , :3 ]
4270
4309
a = sos [section , 3 :]
4271
- zi [section ] = scale * lfilter_zi (b , a )
4310
+ zi [section , ... ] = scale * lfilter_zi (b , a )
4272
4311
# If H(z) = B(z)/A(z) is this section's transfer function, then
4273
4312
# b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
4274
4313
# state value of this section's step response.
4275
- scale *= b .sum () / a .sum ()
4314
+ scale *= xp .sum (b ) / xp .sum (a )
4276
4315
4277
4316
return zi
4278
4317
@@ -4614,6 +4653,8 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
4614
4653
2.875334415008979e-10
4615
4654
4616
4655
"""
4656
+ xp = array_namespace (b , a , x )
4657
+
4617
4658
b = np .atleast_1d (b )
4618
4659
a = np .atleast_1d (a )
4619
4660
x = np .asarray (x )
@@ -4623,7 +4664,7 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
4623
4664
4624
4665
if method == "gust" :
4625
4666
y , z1 , z2 = _filtfilt_gust (b , a , x , axis = axis , irlen = irlen )
4626
- return y
4667
+ return xp . asarray ( y )
4627
4668
4628
4669
# method == "pad"
4629
4670
edge , ext = _validate_pad (padtype , padlen , x , axis ,
@@ -4655,7 +4696,7 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
4655
4696
# Slice the actual signal from the extended signal.
4656
4697
y = axis_slice (y , start = edge , stop = - edge , axis = axis )
4657
4698
4658
- return y
4699
+ return xp . asarray ( y )
4659
4700
4660
4701
4661
4702
def _validate_pad (padtype , padlen , x , axis , ntaps ):
@@ -4769,10 +4810,17 @@ def sosfilt(sos, x, axis=-1, zi=None):
4769
4810
>>> plt.show()
4770
4811
4771
4812
"""
4772
- _reject_objects (sos , 'sosfilt' )
4773
- _reject_objects (x , 'sosfilt' )
4774
- if zi is not None :
4775
- _reject_objects (zi , 'sosfilt' )
4813
+ try :
4814
+ xp = array_namespace (sos , x , zi )
4815
+ except TypeError :
4816
+ # either in1 or in2 are object arrays
4817
+ xp = np_compat
4818
+
4819
+ if is_numpy (xp ):
4820
+ _reject_objects (sos , 'sosfilt' )
4821
+ _reject_objects (x , 'sosfilt' )
4822
+ if zi is not None :
4823
+ _reject_objects (zi , 'sosfilt' )
4776
4824
4777
4825
x = _validate_x (x )
4778
4826
sos , n_sections = _validate_sos (sos )
@@ -4786,7 +4834,12 @@ def sosfilt(sos, x, axis=-1, zi=None):
4786
4834
if dtype .char not in 'fdgFDGO' :
4787
4835
raise NotImplementedError (f"input type '{ dtype } ' not supported" )
4788
4836
if zi is not None :
4789
- zi = np .array (zi , dtype ) # make a copy so that we can operate in place
4837
+ zi = np .asarray (zi , dtype = dtype )
4838
+
4839
+ # make a copy so that we can operate in place
4840
+ # NB: 1. use xp_copy to paper over numpy 1/2 copy= keyword
4841
+ # 2. make sure the copied zi remains a numpy array
4842
+ zi = xp_copy (zi , xp = array_namespace (zi ))
4790
4843
if zi .shape != x_zi_shape :
4791
4844
raise ValueError ('Invalid zi shape. With axis=%r, an input with '
4792
4845
'shape %r, and an sos array with %d sections, zi '
@@ -4798,7 +4851,7 @@ def sosfilt(sos, x, axis=-1, zi=None):
4798
4851
return_zi = False
4799
4852
axis = axis % x .ndim # make positive
4800
4853
x = np .moveaxis (x , axis , - 1 )
4801
- zi = np .moveaxis (zi , [ 0 , axis + 1 ], [ - 2 , - 1 ] )
4854
+ zi = np .moveaxis (zi , ( 0 , axis + 1 ), ( - 2 , - 1 ) )
4802
4855
x_shape , zi_shape = x .shape , zi .shape
4803
4856
x = np .reshape (x , (- 1 , x .shape [- 1 ]))
4804
4857
x = np .array (x , dtype , order = 'C' ) # make a copy, can modify in place
@@ -4809,10 +4862,10 @@ def sosfilt(sos, x, axis=-1, zi=None):
4809
4862
x = np .moveaxis (x , - 1 , axis )
4810
4863
if return_zi :
4811
4864
zi .shape = zi_shape
4812
- zi = np .moveaxis (zi , [ - 2 , - 1 ], [ 0 , axis + 1 ] )
4813
- out = (x , zi )
4865
+ zi = np .moveaxis (zi , ( - 2 , - 1 ), ( 0 , axis + 1 ) )
4866
+ out = (xp . asarray ( x ), xp . asarray ( zi ) )
4814
4867
else :
4815
- out = x
4868
+ out = xp . asarray ( x )
4816
4869
return out
4817
4870
4818
4871
@@ -4905,6 +4958,8 @@ def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
4905
4958
>>> plt.show()
4906
4959
4907
4960
"""
4961
+ xp = array_namespace (sos , x )
4962
+
4908
4963
sos , n_sections = _validate_sos (sos )
4909
4964
x = _validate_x (x )
4910
4965
@@ -4926,7 +4981,7 @@ def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
4926
4981
y = axis_reverse (y , axis = axis )
4927
4982
if edge > 0 :
4928
4983
y = axis_slice (y , start = edge , stop = - edge , axis = axis )
4929
- return y
4984
+ return xp . asarray ( y )
4930
4985
4931
4986
4932
4987
def decimate (x , q , n = None , ftype = 'iir' , axis = - 1 , zero_phase = True ):
0 commit comments