Skip to content

Commit 4d2a0f3

Browse files
authored
ENH: integrate.nsum: support unimodal functions and infinite lower limit of summation (scipy#21789)
* ENH: integrate.nsum: support infinite lower bounds * ENH: integrate.nsum: extend to unimodal functions
1 parent 04d05f3 commit 4d2a0f3

File tree

2 files changed

+123
-30
lines changed

2 files changed

+123
-30
lines changed

scipy/integrate/_tanhsinh.py

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -902,10 +902,9 @@ def _nsum_iv(f, a, b, step, args, log, maxterms, tolerances):
902902
if not np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.complexfloating):
903903
raise ValueError(message)
904904

905-
valid_a = np.isfinite(a)
906905
valid_b = b >= a # NaNs will be False
907906
valid_step = np.isfinite(step) & (step > 0)
908-
valid_abstep = valid_a & valid_b & valid_step
907+
valid_abstep = valid_b & valid_step
909908

910909
message = '`log` must be True or False.'
911910
if log not in {True, False}:
@@ -948,16 +947,16 @@ def _nsum_iv(f, a, b, step, args, log, maxterms, tolerances):
948947

949948

950949
def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances=None):
951-
r"""Evaluate a convergent, monotonically decreasing finite or infinite series.
950+
r"""Evaluate a convergent finite or infinite series.
952951
953-
For finite `b`, this evaluates::
952+
For finite `a` and `b`, this evaluates::
954953
955954
f(a + np.arange(n)*step).sum()
956955
957956
where ``n = int((b - a) / step) + 1``, where `f` is smooth, positive, and
958-
monotone decreasing. `b` may be very large or infinite, in which case
959-
a partial sum is evaluated directly and the remainder is approximated using
960-
integration.
957+
unimodal. The number of terms in the sum may be very large or infinite,
958+
in which case a partial sum is evaluated directly and the remainder is
959+
approximated using integration.
961960
962961
Parameters
963962
----------
@@ -975,14 +974,11 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
975974
array ``x`` or the arrays in ``args``, and it must return NaN where
976975
the argument is NaN.
977976
978-
`f` must represent a smooth, positive, and monotone decreasing
979-
function of `x` defined at *all reals* between `a` and `b`.
980-
`nsum` performs no checks to verify that these conditions
981-
are met and may return erroneous results if they are violated.
977+
`f` must represent a smooth, positive, unimodal function of `x` defined at
978+
*all reals* between `a` and `b`.
982979
a, b : float array_like
983980
Real lower and upper limits of summed terms. Must be broadcastable.
984-
Each element of `a` must be finite and less than the corresponding
985-
element in `b`, but elements of `b` may be infinite.
981+
Each element of `a` must be less than the corresponding element in `b`.
986982
step : float array_like
987983
Finite, positive, real step between summed terms. Must be broadcastable
988984
with `a` and `b`. Note that the number of terms included in the sum will
@@ -1033,6 +1029,9 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
10331029
- ``-4`` : The magnitude of the last term of the partial sum exceeds
10341030
the tolerances, so the error estimate exceeds the tolerances.
10351031
Consider increasing `maxterms` or loosening `tolerances`.
1032+
Alternatively, the callable may not be unimodal, or the limits of
1033+
summation may be too far from the function maximum. Consider
1034+
increasing `maxterms` or breaking the sum into pieces.
10361035
10371036
sum : float array
10381037
An estimate of the sum.
@@ -1077,12 +1076,20 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
10771076
that appear in the sum has little effect. If there is not a natural
10781077
extension of the function to all reals, consider using linear interpolation,
10791078
which is easy to evaluate and preserves monotonicity.
1080-
1079+
10811080
The approach described above is generalized for non-unit
10821081
`step` and finite `b` that is too large for direct evaluation of the sum,
1083-
i.e. ``b - a + 1 > maxterms``.
1082+
i.e. ``b - a + 1 > maxterms``. It is further generalized to unimodal
1083+
functions by directly summing terms surrounding the maximum.
1084+
This strategy may fail:
1085+
1086+
- If the left limit is finite and the maximum is far from it.
1087+
- If the right limit is finite and the maximum is far from it.
1088+
- If both limits are finite and the maximum is far from the origin.
10841089
1085-
Although the callable `f` must be non-negative and monotonically decreasing,
1090+
In these cases, accuracy may be poor, and `nsum` may return status code ``4``.
1091+
1092+
Although the callable `f` must be non-negative and unimodal,
10861093
`nsum` can be used to evaluate more general forms of series. For instance, to
10871094
evaluate an alternating series, pass a callable that returns the difference
10881095
between pairs of adjacent terms, and adjust `step` accordingly. See Examples.
@@ -1127,7 +1134,6 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
11271134
# Potential future work:
11281135
# - improve error estimate of `_direct` sum
11291136
# - add other methods for convergence acceleration (Richardson, epsilon)
1130-
# - support infinite lower limit?
11311137
# - support negative monotone increasing functions?
11321138
# - b < a / negative step?
11331139
# - complex-valued function?
@@ -1147,7 +1153,8 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
11471153
step = np.broadcast_to(step, shape).ravel().astype(dtype)
11481154
valid_abstep = np.broadcast_to(valid_abstep, shape).ravel()
11491155
nterms = np.floor((b - a) / step)
1150-
b = a + nterms*step
1156+
finite_terms = np.isfinite(nterms)
1157+
b[finite_terms] = a[finite_terms] + nterms[finite_terms]*step[finite_terms]
11511158

11521159
# Define constants
11531160
eps = np.finfo(dtype).eps
@@ -1163,9 +1170,15 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
11631170
nfev = np.ones(len(a), dtype=int) # one function evaluation above
11641171

11651172
# Branch for direct sum evaluation / integral approximation / invalid input
1166-
i1 = (nterms + 1 <= maxterms) & valid_abstep
1167-
i2 = (nterms + 1 > maxterms) & valid_abstep
1168-
i3 = ~valid_abstep
1173+
i0 = ~valid_abstep # invalid
1174+
i1 = (nterms + 1 <= maxterms) & ~i0 # direct sum evaluation
1175+
i2 = np.isfinite(a) & ~i1 & ~i0 # infinite sum to the right
1176+
i3 = np.isfinite(b) & ~i2 & ~i1 & ~i0 # infinite sum to the left
1177+
i4 = ~i3 & ~i2 & ~i1 & ~i0 # infinite sum on both sides
1178+
1179+
if np.any(i0):
1180+
S[i0], E[i0] = np.nan, np.nan
1181+
status[i0] = -1
11691182

11701183
if np.any(i1):
11711184
args_direct = [arg[i1] for arg in args]
@@ -1181,8 +1194,40 @@ def nsum(f, a, b, *, step=1, args=(), log=False, maxterms=int(2**20), tolerances
11811194
nfev[i2] += tmp[-1]
11821195

11831196
if np.any(i3):
1184-
S[i3], E[i3] = np.nan, np.nan
1185-
status[i3] = -1
1197+
args_indirect = [arg[i3] for arg in args]
1198+
def _f(x, *args): return f(-x, *args)
1199+
tmp = _integral_bound(_f, -b[i3], -a[i3], step[i3], args_indirect, constants)
1200+
S[i3], E[i3], status[i3] = tmp[:-1]
1201+
nfev[i3] += tmp[-1]
1202+
1203+
if np.any(i4):
1204+
args_indirect = [arg[i4] for arg in args]
1205+
1206+
# There are two obvious high-level strategies:
1207+
# - Do two separate half-infinite sums (e.g. from -inf to 0 and 1 to inf)
1208+
# - Make a callable that returns f(x) + f(-x) and do a single half-infinite sum
1209+
# I thought the latter would have about half the overhead, so I went that way.
1210+
# Then there are two ways of ensuring that f(0) doesn't get counted twice.
1211+
# - Evaluate the sum from 1 to inf and add f(0)
1212+
# - Evaluate the sum from 0 to inf and subtract f(0)
1213+
# - Evaluate the sum from 0 to inf, but apply a weight of 0.5 when `x = 0`
1214+
# The last option has more overhead, but is simpler to implement correctly
1215+
# (especially getting the status message right)
1216+
if log:
1217+
def _f(x, *args):
1218+
log_factor = np.where(x==0, np.log(0.5), 0)
1219+
out = np.stack([f(x, *args), f(-x, *args)], axis=0)
1220+
return special.logsumexp(out, axis=0) + log_factor
1221+
1222+
else:
1223+
def _f(x, *args):
1224+
factor = np.where(x==0, 0.5, 1)
1225+
return (f(x, *args) + f(-x, *args)) * factor
1226+
1227+
zero = np.zeros_like(a[i4])
1228+
tmp = _integral_bound(_f, zero, b[i4], step[i4], args_indirect, constants)
1229+
S[i4], E[i4], status[i4] = tmp[:-1]
1230+
nfev[i4] += 2*tmp[-1]
11861231

11871232
# Return results
11881233
S, E = S.reshape(shape)[()], E.reshape(shape)[()]
@@ -1265,17 +1310,20 @@ def _integral_bound(f, a, b, step, args, constants):
12651310
# Find the location of a term that is less than the tolerance (if possible)
12661311
log2maxterms = np.floor(np.log2(maxterms)) if maxterms else 0
12671312
n_steps = np.concatenate([2**np.arange(0, log2maxterms), [maxterms]], dtype=dtype)
1268-
nfev = len(n_steps)
1313+
nfev = len(n_steps) * 2
12691314
ks = a2 + n_steps * step2
12701315
fks = f(ks, *args2)
1271-
n_fk_insufficient = np.sum(fks > tol[:, np.newaxis], axis=-1)
1316+
fksp1 = f(ks + step2, *args2) # check that the function is decreasing
1317+
fk_insufficient = (fks > tol[:, np.newaxis]) | (fksp1 > fks)
1318+
n_fk_insufficient = np.sum(fk_insufficient, axis=-1)
12721319
nt = np.minimum(n_fk_insufficient, n_steps.shape[-1]-1)
12731320
n_steps = n_steps[nt]
12741321

1275-
# If `maxterms` is insufficient (i.e. the magnitude of the last term of the
1276-
# partial sum exceeds the tolerance), we can finish the calculation and report
1277-
# valid sum and error estimates, but we'll have nonzero status.
1278-
i_fk_insufficient = (n_fk_insufficient == nfev)
1322+
# If `maxterms` is insufficient (i.e. either the magnitude of the last term of the
1323+
# partial sum exceeds the tolerance or the function is not decreasing), finish the
1324+
# calculation, but report nonzero status. (Improvement: separate the status codes
1325+
# for these two cases.)
1326+
i_fk_insufficient = (n_fk_insufficient == nfev//2)
12791327

12801328
# Directly evaluate the sum up to this term
12811329
k = a + n_steps * step

scipy/integrate/tests/test_tanhsinh.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ def test_input_validation(self):
801801
nsum(f, f.a, f.b, tolerances=dict(rtol=pytest))
802802

803803
with np.errstate(all='ignore'):
804-
res = nsum(f, [np.nan, -np.inf, np.inf], 1)
804+
res = nsum(f, [np.nan, np.inf], 1)
805805
assert np.all((res.status == -1) & np.isnan(res.sum)
806806
& np.isnan(res.error) & ~res.success & res.nfev == 1)
807807
res = nsum(f, 10, [np.nan, 1])
@@ -971,6 +971,51 @@ def test_inclusive(self):
971971
assert np.all(res.sum > (ref.sum - res.error))
972972
assert np.all(res.sum < (ref.sum + res.error))
973973

974+
@pytest.mark.parametrize('log', [True, False])
975+
def test_infinite_bounds(self, log):
976+
a = [1, -np.inf, -np.inf]
977+
b = [np.inf, -1, np.inf]
978+
c = [1, 2, 3]
979+
980+
def f(x, a):
981+
return (np.log(np.tanh(a / 2)) - a*np.abs(x) if log
982+
else np.tanh(a/2) * np.exp(-a*np.abs(x)))
983+
984+
res = nsum(f, a, b, args=(c,), log=log)
985+
ref = [stats.dlaplace.sf(0, 1), stats.dlaplace.sf(0, 2), 1]
986+
ref = np.log(ref) if log else ref
987+
atol = 1e-10 if log else 0
988+
assert_allclose(res.sum, ref, atol=atol)
989+
990+
# Make sure the sign of `x` passed into `f` is correct.
991+
def f(x, c):
992+
return -3*np.log(c*x) if log else 1 / (c*x)**3
993+
994+
res = nsum(f, [1, -np.inf], [np.inf, -1], args=([1, -1],), log=log)
995+
ref = np.log(special.zeta(3)) if log else special.zeta(3)
996+
assert_allclose(res.sum, ref)
997+
998+
def test_decreasing_check(self):
999+
# Test accuracy when we start sum on an uphill slope.
1000+
# Without the decreasing check, the terms would look small enough to
1001+
# use the integral approximation. Because the function is not decreasing,
1002+
# the error is not bounded by the magnitude of the last term of the
1003+
# partial sum. In this case, the error would be ~1e-4, causing the test
1004+
# to fail.
1005+
def f(x):
1006+
return np.exp(-x ** 2)
1007+
1008+
res = nsum(f, -25, np.inf)
1009+
1010+
# Reference computed with mpmath:
1011+
# from mpmath import mp
1012+
# mp.dps = 50
1013+
# def fmp(x): return mp.exp(-x**2)
1014+
# ref = mp.nsum(fmp, (-25, 0)) + mp.nsum(fmp, (1, mp.inf))
1015+
ref = 1.772637204826652
1016+
1017+
np.testing.assert_allclose(res.sum, ref, rtol=1e-15)
1018+
9741019
def test_special_case(self):
9751020
# test equal lower/upper limit
9761021
f = self.f1

0 commit comments

Comments
 (0)