Skip to content

Commit 1e3fb7c

Browse files
committed
fix jacobian
1 parent cb65057 commit 1e3fb7c

File tree

1 file changed

+9
-22
lines changed

1 file changed

+9
-22
lines changed

lib/NonlinearSolveBase/src/jacobian.jl

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function construct_jacobian_cache(
6161
end
6262

6363
J = if !needs_jac
64-
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
64+
StatefulJacobianOperator(JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff), cache.u, cache.p)
6565
else
6666
if f.jac_prototype === nothing
6767
# While this is technically wasteful, it gives out the type of the Jacobian
@@ -87,7 +87,7 @@ function construct_jacobian_cache(
8787
end
8888
end
8989

90-
return JacobianCache(J, f, fu, u, p, stats, autodiff, di_extras)
90+
return JacobianCache(J, f, fu, p, stats, autodiff, di_extras)
9191
end
9292

9393
function construct_jacobian_cache(
@@ -107,45 +107,32 @@ function construct_jacobian_cache(
107107
@assert !(autodiff isa AutoSparse) "`autodiff` cannot be `AutoSparse` for scalar \
108108
nonlinear problems."
109109
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
110-
return JacobianCache(fu, f, fu, u, p, stats, autodiff, di_extras)
110+
return JacobianCache(fu, f, fu, p, stats, autodiff, di_extras)
111111
end
112112

113113
@concrete mutable struct JacobianCache <: AbstractJacobianCache
114114
J
115115
f <: NonlinearFunction
116116
fu
117-
u
118117
p
119118
stats::NLStats
120119
autodiff
121120
di_extras
122121
end
123122

124-
function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, u0 = cache.u, kwargs...)
125-
cache.u = u0
123+
function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, kwargs...)
126124
cache.p = p
127125
end
128126

129127
# Core Computation
130-
function (cache::JacobianCache)(u)
131-
cache.u = u
132-
cache()
133-
end
134-
function (cache::JacobianCache{<:JacobianOperator})(::Nothing)
135-
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
136-
end
137128
(cache::JacobianCache)(::Nothing) = cache.J
138-
139-
## Operator
140-
function (cache::JacobianCache{<:JacobianOperator})()
141-
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
142-
end
129+
(cache::JacobianCache{<:Number})(::Nothing) = cache.J
143130

144131
## Numbers
145-
function (cache::JacobianCache{<:Number})()
132+
function (cache::JacobianCache{<:Number})(u)
146133
cache.stats.njacs += 1
147134

148-
(; f, J, u, p) = cache
135+
(; f, J, p) = cache
149136
cache.J = if SciMLBase.has_jac(f)
150137
f.jac(u, p)
151138
elseif SciMLBase.has_vjp(f)
@@ -159,9 +146,9 @@ function (cache::JacobianCache{<:Number})()
159146
end
160147

161148
## Actually Compute the Jacobian
162-
function (cache::JacobianCache)()
149+
function (cache::JacobianCache)(u)
163150
cache.stats.njacs += 1
164-
(; f, J, u, p) = cache
151+
(; f, J, p) = cache
165152
if SciMLBase.isinplace(f)
166153
if SciMLBase.has_jac(f)
167154
f.jac(J, u, p)

0 commit comments

Comments
 (0)