3
3
import numpy as np
4
4
import scipy ._lib ._elementwise_iterative_method as eim
5
5
from scipy ._lib ._util import _RichResult
6
- from scipy ._lib ._array_api import array_namespace , xp_sign
6
+ from scipy ._lib ._array_api import array_namespace , xp_sign , xp_copy , xp_take_along_axis
7
7
8
8
_EERRORINCREASE = - 1 # used in derivative
9
9
@@ -882,30 +882,34 @@ def jacobian(f, x, *, tolerances=None, maxiter=10, order=8, initial_step=0.5,
882
882
True
883
883
884
884
"""
885
- x = np .asarray (x )
886
- int_dtype = np .issubdtype (x .dtype , np .integer )
887
- x0 = np .asarray (x , dtype = float ) if int_dtype else x
885
+ xp = array_namespace (x )
886
+ x = xp .asarray (x )
887
+ int_dtype = xp .isdtype (x .dtype , 'integral' )
888
+ x0 = xp .asarray (x , dtype = xp .asarray (1.0 ).dtype ) if int_dtype else x
888
889
889
890
if x0 .ndim < 1 :
890
891
message = "Argument `x` must be at least 1-D."
891
892
raise ValueError (message )
892
893
893
894
m = x0 .shape [0 ]
894
- i = np .arange (m )
895
+ i = xp .arange (m )
895
896
896
897
def wrapped (x ):
897
898
p = () if x .ndim == x0 .ndim else (x .shape [- 1 ],) # number of abscissae
898
- new_dims = ( 1 ,) if x . ndim == x0 . ndim else ( 1 , - 1 )
899
+
899
900
new_shape = (m , m ) + x0 .shape [1 :] + p
900
- xph = np .expand_dims (x0 , new_dims )
901
- xph = np .broadcast_to (xph , new_shape ).copy ()
901
+ xph = xp .expand_dims (x0 , axis = 1 )
902
+ if x .ndim != x0 .ndim :
903
+ xph = xp .expand_dims (xph , axis = - 1 )
904
+ xph = xp_copy (xp .broadcast_to (xph , new_shape ), xp = xp )
902
905
xph [i , i ] = x
903
906
return f (xph )
904
907
905
908
res = derivative (wrapped , x , tolerances = tolerances ,
906
909
maxiter = maxiter , order = order , initial_step = initial_step ,
907
910
step_factor = step_factor , preserve_shape = True ,
908
911
step_direction = step_direction )
912
+
909
913
del res .x # the user knows `x`, and the way it gets broadcasted is meaningless here
910
914
return res
911
915
@@ -1069,9 +1073,10 @@ def hessian(f, x, *, tolerances=None, maxiter=10,
1069
1073
atol = tolerances .get ('atol' , None )
1070
1074
rtol = tolerances .get ('rtol' , None )
1071
1075
1072
- x = np .asarray (x )
1073
- dtype = x .dtype if np .issubdtype (x .dtype , np .inexact ) else np .float64
1074
- finfo = np .finfo (dtype )
1076
+ xp = array_namespace (x )
1077
+ x = xp .asarray (x )
1078
+ dtype = x .dtype if not xp .isdtype (x .dtype , 'integral' ) else xp .asarray (1. ).dtype
1079
+ finfo = xp .finfo (dtype )
1075
1080
rtol = finfo .eps ** 0.5 if rtol is None else rtol # keep same as `derivative`
1076
1081
1077
1082
# tighten the inner tolerance to make the inner error negligible
@@ -1091,8 +1096,9 @@ def df(x):
1091
1096
nfev = [] # track inner function evaluations
1092
1097
res = jacobian (df , x , tolerances = tolerances , ** kwargs ) # jacobian of jacobian
1093
1098
1094
- nfev = np .cumsum (nfev , axis = 0 )
1095
- res .nfev = np .take_along_axis (nfev , res .nit [np .newaxis , ...], axis = 0 )[0 ]
1099
+ nfev = xp .cumulative_sum (xp .stack (nfev ), axis = 0 )
1100
+ res_nit = xp .astype (res .nit [xp .newaxis , ...], xp .int64 ) # appease torch
1101
+ res .nfev = xp_take_along_axis (nfev , res_nit , axis = 0 )[0 ]
1096
1102
res .ddf = res .df
1097
1103
del res .df # this is renamed to ddf
1098
1104
del res .nit # this is only the outer-jacobian nit
0 commit comments