Skip to content

Commit 82b7a32

Browse files
j-bowhaymdhaber
andauthored
ENH: differentiate: add array API support to jacobian and hessian (scipy#21811)
* ENH: differentiate: add array API support to `jacobian` and `hessian` --------- Co-authored-by: Matt Haberland <[email protected]>
1 parent efe5fe7 commit 82b7a32

File tree

3 files changed

+112
-78
lines changed

3 files changed

+112
-78
lines changed

scipy/_lib/_elementwise_iterative_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
# `scipy.optimize._differentiate._differentiate for numerical differentiation,
1010
# `scipy.optimize._bracket._bracket_root for finding rootfinding brackets,
1111
# `scipy.optimize._bracket._bracket_minimize for finding minimization brackets,
12-
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature.
12+
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature,
13+
# `scipy.differentiate.derivative` for finite difference based differentiation.
1314

1415
import math
1516
import numpy as np

scipy/differentiate/_differentiate.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import scipy._lib._elementwise_iterative_method as eim
55
from scipy._lib._util import _RichResult
6-
from scipy._lib._array_api import array_namespace, xp_sign
6+
from scipy._lib._array_api import array_namespace, xp_sign, xp_copy, xp_take_along_axis
77

88
_EERRORINCREASE = -1 # used in derivative
99

@@ -882,30 +882,34 @@ def jacobian(f, x, *, tolerances=None, maxiter=10, order=8, initial_step=0.5,
882882
True
883883
884884
"""
885-
x = np.asarray(x)
886-
int_dtype = np.issubdtype(x.dtype, np.integer)
887-
x0 = np.asarray(x, dtype=float) if int_dtype else x
885+
xp = array_namespace(x)
886+
x = xp.asarray(x)
887+
int_dtype = xp.isdtype(x.dtype, 'integral')
888+
x0 = xp.asarray(x, dtype=xp.asarray(1.0).dtype) if int_dtype else x
888889

889890
if x0.ndim < 1:
890891
message = "Argument `x` must be at least 1-D."
891892
raise ValueError(message)
892893

893894
m = x0.shape[0]
894-
i = np.arange(m)
895+
i = xp.arange(m)
895896

896897
def wrapped(x):
897898
p = () if x.ndim == x0.ndim else (x.shape[-1],) # number of abscissae
898-
new_dims = (1,) if x.ndim == x0.ndim else (1, -1)
899+
899900
new_shape = (m, m) + x0.shape[1:] + p
900-
xph = np.expand_dims(x0, new_dims)
901-
xph = np.broadcast_to(xph, new_shape).copy()
901+
xph = xp.expand_dims(x0, axis=1)
902+
if x.ndim != x0.ndim:
903+
xph = xp.expand_dims(xph, axis=-1)
904+
xph = xp_copy(xp.broadcast_to(xph, new_shape), xp=xp)
902905
xph[i, i] = x
903906
return f(xph)
904907

905908
res = derivative(wrapped, x, tolerances=tolerances,
906909
maxiter=maxiter, order=order, initial_step=initial_step,
907910
step_factor=step_factor, preserve_shape=True,
908911
step_direction=step_direction)
912+
909913
del res.x # the user knows `x`, and the way it gets broadcasted is meaningless here
910914
return res
911915

@@ -1069,9 +1073,10 @@ def hessian(f, x, *, tolerances=None, maxiter=10,
10691073
atol = tolerances.get('atol', None)
10701074
rtol = tolerances.get('rtol', None)
10711075

1072-
x = np.asarray(x)
1073-
dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
1074-
finfo = np.finfo(dtype)
1076+
xp = array_namespace(x)
1077+
x = xp.asarray(x)
1078+
dtype = x.dtype if not xp.isdtype(x.dtype, 'integral') else xp.asarray(1.).dtype
1079+
finfo = xp.finfo(dtype)
10751080
rtol = finfo.eps**0.5 if rtol is None else rtol # keep same as `derivative`
10761081

10771082
# tighten the inner tolerance to make the inner error negligible
@@ -1091,8 +1096,9 @@ def df(x):
10911096
nfev = [] # track inner function evaluations
10921097
res = jacobian(df, x, tolerances=tolerances, **kwargs) # jacobian of jacobian
10931098

1094-
nfev = np.cumsum(nfev, axis=0)
1095-
res.nfev = np.take_along_axis(nfev, res.nit[np.newaxis, ...], axis=0)[0]
1099+
nfev = xp.cumulative_sum(xp.stack(nfev), axis=0)
1100+
res_nit = xp.astype(res.nit[xp.newaxis, ...], xp.int64) # appease torch
1101+
res.nfev = xp_take_along_axis(nfev, res_nit, axis=0)[0]
10961102
res.ddf = res.df
10971103
del res.df # this is renamed to ddf
10981104
del res.nit # this is only the outer-jacobian nit

0 commit comments

Comments
 (0)