5
5
6
6
import numpy as np
7
7
from numpy import (atleast_1d , poly , polyval , roots , real , asarray ,
8
- resize , pi , absolute , sqrt , tan , log10 ,
8
+ pi , absolute , sqrt , tan , log10 ,
9
9
arcsinh , sin , exp , cosh , arccosh , ceil , conjugate ,
10
10
zeros , sinh , append , concatenate , prod , ones , full , array ,
11
11
mintypecode )
17
17
from scipy ._lib ._util import float_factorial
18
18
from scipy .signal ._arraytools import _validate_fs
19
19
20
+ import scipy ._lib .array_api_extra as xpx
21
+ from scipy ._lib ._array_api import array_namespace , xp_promote , xp_size
22
+
20
23
21
24
__all__ = ['findfreqs' , 'freqs' , 'freqz' , 'tf2zpk' , 'zpk2tf' , 'normalize' ,
22
25
'lp2lp' , 'lp2hp' , 'lp2bp' , 'lp2bs' , 'bilinear' , 'iirdesign' ,
@@ -1676,7 +1679,7 @@ def idx_worst(p):
1676
1679
return sos
1677
1680
1678
1681
1679
- def _align_nums (nums ):
1682
+ def _align_nums (nums , xp ):
1680
1683
"""Aligns the shapes of multiple numerators.
1681
1684
1682
1685
Given an array of numerator coefficient arrays [[a_1, a_2,...,
@@ -1701,19 +1704,19 @@ def _align_nums(nums):
1701
1704
# The statement can throw a ValueError if one
1702
1705
# of the numerators is a single digit and another
1703
1706
# is array-like e.g. if nums = [5, [1, 2, 3]]
1704
- nums = asarray (nums )
1707
+ nums = xp . asarray (nums )
1705
1708
1706
- if not np . issubdtype (nums .dtype , np . number ):
1709
+ if not xp . isdtype (nums .dtype , "numeric" ):
1707
1710
raise ValueError ("dtype of numerator is non-numeric" )
1708
1711
1709
1712
return nums
1710
1713
1711
1714
except ValueError :
1712
- nums = [np . atleast_1d ( num ) for num in nums ]
1713
- max_width = max (num . size for num in nums )
1715
+ nums = [xpx . atleast_nd ( xp . asarray ( num ), ndim = 1 ) for num in nums ]
1716
+ max_width = max (xp_size ( num ) for num in nums )
1714
1717
1715
1718
# pre-allocate
1716
- aligned_nums = np .zeros ((len ( nums ) , max_width ))
1719
+ aligned_nums = xp .zeros ((nums . shape [ 0 ] , max_width ))
1717
1720
1718
1721
# Create numerators with padded zeros
1719
1722
for index , num in enumerate (nums ):
@@ -1722,6 +1725,26 @@ def _align_nums(nums):
1722
1725
return aligned_nums
1723
1726
1724
1727
1728
+ def _trim_zeros (filt , trim = 'fb' ):
1729
+ # https://github.com/numpy/numpy/blob/v2.1.0/numpy/lib/_function_base_impl.py#L1874-L1925
1730
+ first = 0
1731
+ trim = trim .upper ()
1732
+ if 'F' in trim :
1733
+ for i in filt :
1734
+ if i != 0. :
1735
+ break
1736
+ else :
1737
+ first = first + 1
1738
+ last = filt .shape [0 ]
1739
+ if 'B' in trim :
1740
+ for i in filt [::- 1 ]:
1741
+ if i != 0. :
1742
+ break
1743
+ else :
1744
+ last = last - 1
1745
+ return filt [first :last ]
1746
+
1747
+
1725
1748
def normalize (b , a ):
1726
1749
"""Normalize numerator/denominator of a continuous-time transfer function.
1727
1750
@@ -1778,30 +1801,33 @@ def normalize(b, a):
1778
1801
Badly conditioned filter coefficients (numerator): the results may be meaningless
1779
1802
1780
1803
"""
1781
- num , den = b , a
1804
+ xp = array_namespace ( b , a )
1782
1805
1783
- den = np .asarray (den )
1784
- den = np .atleast_1d (den )
1785
- num = np .atleast_2d (_align_nums (num ))
1806
+ den = xp .asarray (a )
1807
+ den = xpx .atleast_nd (den , ndim = 1 , xp = xp )
1808
+
1809
+ num = xp .asarray (b )
1810
+ num = xpx .atleast_nd (_align_nums (num , xp ), ndim = 2 , xp = xp )
1786
1811
1787
1812
if den .ndim != 1 :
1788
1813
raise ValueError ("Denominator polynomial must be rank-1 array." )
1789
1814
if num .ndim > 2 :
1790
1815
raise ValueError ("Numerator polynomial must be rank-1 or"
1791
1816
" rank-2 array." )
1792
- if np .all (den == 0 ):
1817
+ if xp .all (den == 0 ):
1793
1818
raise ValueError ("Denominator must have at least on nonzero element." )
1794
1819
1795
1820
# Trim leading zeros in denominator, leave at least one.
1796
- den = np . trim_zeros (den , 'f' )
1821
+ den = _trim_zeros (den , 'f' )
1797
1822
1798
1823
# Normalize transfer function
1799
1824
num , den = num / den [0 ], den / den [0 ]
1800
1825
1801
1826
# Count numerator columns that are all zero
1802
1827
leading_zeros = 0
1803
- for col in num .T :
1804
- if np .allclose (col , 0 , atol = 1e-14 ):
1828
+ for j in range (num .shape [- 1 ]):
1829
+ col = num [:, j ]
1830
+ if xp .all (xp .abs (col ) <= 1e-14 ):
1805
1831
leading_zeros += 1
1806
1832
else :
1807
1833
break
@@ -1879,22 +1905,49 @@ def lp2lp(b, a, wo=1.0):
1879
1905
>>> plt.legend()
1880
1906
1881
1907
"""
1882
- a , b = map (atleast_1d , (a , b ))
1908
+ xp = array_namespace (a , b )
1909
+ a , b = map (xp .asarray , (a , b ))
1910
+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
1911
+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
1912
+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
1913
+
1883
1914
try :
1884
1915
wo = float (wo )
1885
1916
except TypeError :
1886
1917
wo = float (wo [0 ])
1887
- d = len ( a )
1888
- n = len ( b )
1918
+ d = a . shape [ 0 ]
1919
+ n = b . shape [ 0 ]
1889
1920
M = max ((d , n ))
1890
- pwo = pow ( wo , np .arange (M - 1 , - 1 , - 1 ) )
1921
+ pwo = wo ** xp .arange (M - 1 , - 1 , - 1 , dtype = xp . float64 )
1891
1922
start1 = max ((n - d , 0 ))
1892
1923
start2 = max ((d - n , 0 ))
1893
1924
b = b * pwo [start1 ] / pwo [start2 :]
1894
1925
a = a * pwo [start1 ] / pwo [start1 :]
1895
1926
return normalize (b , a )
1896
1927
1897
1928
1929
+ def _resize (a , new_shape , xp ):
1930
+ # https://github.com/numpy/numpy/blob/v2.2.4/numpy/_core/fromnumeric.py#L1535
1931
+ a = xp .reshape (a , (- 1 ,))
1932
+
1933
+ new_size = 1
1934
+ for dim_length in new_shape :
1935
+ new_size *= dim_length
1936
+ if dim_length < 0 :
1937
+ raise ValueError (
1938
+ 'all elements of `new_shape` must be non-negative'
1939
+ )
1940
+
1941
+ if xp_size (a ) == 0 or new_size == 0 :
1942
+ # First case must zero fill. The second would have repeats == 0.
1943
+ return xp .zeros_like (a , shape = new_shape )
1944
+
1945
+ repeats = - (- new_size // xp_size (a )) # ceil division
1946
+ a = xp .concat ((a ,) * repeats )[:new_size ]
1947
+
1948
+ return xp .reshape (a , new_shape )
1949
+
1950
+
1898
1951
def lp2hp (b , a , wo = 1.0 ):
1899
1952
r"""
1900
1953
Transform a lowpass filter prototype to a highpass filter.
@@ -1953,27 +2006,34 @@ def lp2hp(b, a, wo=1.0):
1953
2006
>>> plt.legend()
1954
2007
1955
2008
"""
1956
- a , b = map (atleast_1d , (a , b ))
2009
+ xp = array_namespace (a , b )
2010
+
2011
+ a , b = map (xp .asarray , (a , b ))
2012
+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
2013
+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
2014
+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
2015
+
1957
2016
try :
1958
2017
wo = float (wo )
1959
2018
except TypeError :
1960
2019
wo = float (wo [0 ])
1961
- d = len ( a )
1962
- n = len ( b )
2020
+ d = a . shape [ 0 ]
2021
+ n = b . shape [ 0 ]
1963
2022
if wo != 1 :
1964
- pwo = pow ( wo , np .arange (max ((d , n ))) )
2023
+ pwo = wo ** xp .arange (max ((d , n )), dtype = xp . float64 )
1965
2024
else :
1966
- pwo = np .ones (max ((d , n )), b .dtype . char )
2025
+ pwo = xp .ones (max ((d , n )), dtype = b .dtype )
1967
2026
if d >= n :
1968
- outa = a [::- 1 ] * pwo
1969
- outb = resize (b , (d ,))
2027
+ outa = xp .flip (a ) * pwo
2028
+ outb = xp .concat ((xp .zeros (n , dtype = b .dtype ), ))
2029
+ outb = _resize (b , (d ,), xp = xp )
1970
2030
outb [n :] = 0.0
1971
- outb [:n ] = b [:: - 1 ] * pwo [:n ]
2031
+ outb [:n ] = xp . flip ( b ) * pwo [:n ]
1972
2032
else :
1973
- outb = b [:: - 1 ] * pwo
1974
- outa = resize (a , (n ,))
2033
+ outb = xp . flip ( b ) * pwo
2034
+ outa = _resize (a , (n ,), xp = xp )
1975
2035
outa [d :] = 0.0
1976
- outa [:d ] = a [:: - 1 ] * pwo [:d ]
2036
+ outa [:d ] = xp . flip ( a ) * pwo [:d ]
1977
2037
1978
2038
return normalize (outb , outa )
1979
2039
@@ -2038,16 +2098,20 @@ def lp2bp(b, a, wo=1.0, bw=1.0):
2038
2098
>>> plt.ylabel('Amplitude [dB]')
2039
2099
>>> plt.legend()
2040
2100
"""
2101
+ xp = array_namespace (a , b )
2102
+
2103
+ a , b = map (xp .asarray , (a , b ))
2104
+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
2105
+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
2106
+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
2041
2107
2042
- a , b = map (atleast_1d , (a , b ))
2043
- D = len (a ) - 1
2044
- N = len (b ) - 1
2045
- artype = mintypecode ((a , b ))
2108
+ D = a .shape [0 ] - 1
2109
+ N = b .shape [0 ] - 1
2046
2110
ma = max ([N , D ])
2047
2111
Np = N + ma
2048
2112
Dp = D + ma
2049
- bprime = np .empty (Np + 1 , artype )
2050
- aprime = np .empty (Dp + 1 , artype )
2113
+ bprime = xp .empty (Np + 1 , dtype = b . dtype )
2114
+ aprime = xp .empty (Dp + 1 , dtype = a . dtype )
2051
2115
wosq = wo * wo
2052
2116
for j in range (Np + 1 ):
2053
2117
val = 0.0
@@ -2126,15 +2190,20 @@ def lp2bs(b, a, wo=1.0, bw=1.0):
2126
2190
>>> plt.ylabel('Amplitude [dB]')
2127
2191
>>> plt.legend()
2128
2192
"""
2129
- a , b = map (atleast_1d , (a , b ))
2130
- D = len (a ) - 1
2131
- N = len (b ) - 1
2132
- artype = mintypecode ((a , b ))
2193
+ xp = array_namespace (a , b )
2194
+
2195
+ a , b = map (xp .asarray , (a , b ))
2196
+ a , b = xp_promote (a , b , force_floating = True , xp = xp )
2197
+ a = xpx .atleast_nd (a , ndim = 1 , xp = xp )
2198
+ b = xpx .atleast_nd (b , ndim = 1 , xp = xp )
2199
+
2200
+ D = a .shape [0 ] - 1
2201
+ N = b .shape [0 ] - 1
2133
2202
M = max ([N , D ])
2134
2203
Np = M + M
2135
2204
Dp = M + M
2136
- bprime = np .empty (Np + 1 , artype )
2137
- aprime = np .empty (Dp + 1 , artype )
2205
+ bprime = xp .empty (Np + 1 , dtype = b . dtype )
2206
+ aprime = xp .empty (Dp + 1 , dtype = a . dtype )
2138
2207
wosq = wo * wo
2139
2208
for j in range (Np + 1 ):
2140
2209
val = 0.0
0 commit comments