Skip to content

Commit 2a15022

Browse files
authored
MAINT: interpolate.AAA: improve input validation of max_terms (scipy#21329)
1 parent 3b61e16 commit 2a15022

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

scipy/interpolate/_aaa.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

2626
import warnings
27+
import operator
2728

2829
import numpy as np
2930
import scipy
@@ -52,9 +53,10 @@ class AAA:
5253
rtol : float, optional
5354
Relative tolerance, defaults to ``eps**0.75``. If a small subset of the entries
5455
in `values` are much larger than the rest the default tolerance may be too
55-
loose.
56+
loose.
5657
max_terms : int, optional
5758
Maximum number of terms in the barycentric representation, defaults to ``100``.
59+
Must be greater than or equal to one.
5860
5961
Attributes
6062
----------
@@ -69,7 +71,7 @@ class AAA:
6971
errors : array
7072
Error :math:`|f(z) - r(z)|_\infty` over `points` in the successive iterations
7173
of AAA.
72-
74+
7375
Warns
7476
-----
7577
RuntimeWarning
@@ -93,7 +95,7 @@ class AAA:
9395
where :math:`z_1,\dots,z_m` are real or complex support points selected from
9496
`points`, :math:`f_1,\dots,f_m` are the corresponding real or complex data values
9597
from `values`, and :math:`w_1,\dots,w_m` are real or complex weights.
96-
98+
9799
Each iteration of the algorithm has two parts: the greedy selection the next support
98100
point and the computation of the weights. The first part of each iteration is to
99101
select the next support point to be added :math:`z_{m+1}` from the remaining
@@ -102,7 +104,7 @@ class AAA:
102104
when this maximum is less than ``rtol * np.linalg.norm(f, ord=np.inf)``. This means
103105
the interpolation property is only satisfied up to a tolerance, except at the
104106
support points where approximation exactly interpolates the supplied data.
105-
107+
106108
In the second part of each iteration, the weights :math:`w_j` are selected to solve
107109
the least-squares problem
108110
@@ -195,10 +197,15 @@ def __init__(self, points, values, *, rtol=None, max_terms=100):
195197

196198
if f.size != z.size:
197199
raise ValueError("`points` and `values` must be the same size.")
198-
200+
199201
if not np.all(np.isfinite(z)):
200202
raise ValueError("`points` must be finite.")
201203

204+
max_terms = operator.index(max_terms)
205+
if max_terms < 1:
206+
raise ValueError("`max_terms` must be an integer value greater than or "
207+
"equal to one.")
208+
202209
# Remove infinite or NaN function values and repeated entries
203210
to_keep = (np.isfinite(f)) & (~np.isnan(f))
204211
f = f[to_keep]

scipy/interpolate/tests/test_aaa.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def test_input_validation(self):
4343
AAA([[0], [0]], [[1], [1]])
4444
with pytest.raises(ValueError, match="finite"):
4545
AAA([np.inf], [1])
46+
with pytest.raises(TypeError):
47+
AAA([1], [1], max_terms=1.0)
48+
with pytest.raises(ValueError, match="greater"):
49+
AAA([1], [1], max_terms=-1)
4650

4751
def test_convergence_error(self):
4852
with pytest.warns(RuntimeWarning, match="AAA failed"):

0 commit comments

Comments
 (0)