@@ -61,7 +61,7 @@ function construct_jacobian_cache(
6161 end
6262
6363 J = if ! needs_jac
64- JacobianOperator (prob, fu, u; jvp_autodiff, vjp_autodiff)
64+ StatefulJacobianOperator ( JacobianOperator (prob, fu, u; jvp_autodiff, vjp_autodiff), cache . u, cache . p )
6565 else
6666 if f. jac_prototype === nothing
6767 # While this is technically wasteful, it gives out the type of the Jacobian
@@ -87,7 +87,7 @@ function construct_jacobian_cache(
8787 end
8888 end
8989
90- return JacobianCache (J, f, fu, u, p, stats, autodiff, di_extras)
90+ return JacobianCache (J, f, fu, p, stats, autodiff, di_extras)
9191end
9292
9393function construct_jacobian_cache (
@@ -107,45 +107,32 @@ function construct_jacobian_cache(
107107 @assert ! (autodiff isa AutoSparse) " `autodiff` cannot be `AutoSparse` for scalar \
108108 nonlinear problems."
109109 di_extras = DI. prepare_derivative (f, autodiff, u, Constant (prob. p))
110- return JacobianCache (fu, f, fu, u, p, stats, autodiff, di_extras)
110+ return JacobianCache (fu, f, fu, p, stats, autodiff, di_extras)
111111end
112112
113113@concrete mutable struct JacobianCache <: AbstractJacobianCache
114114 J
115115 f <: NonlinearFunction
116116 fu
117- u
118117 p
119118 stats:: NLStats
120119 autodiff
121120 di_extras
122121end
123122
124- function InternalAPI. reinit! (cache:: JacobianCache ; p = cache. p, u0 = cache. u, kwargs... )
125- cache. u = u0
123+ function InternalAPI. reinit! (cache:: JacobianCache ; p = cache. p, kwargs... )
126124 cache. p = p
127125end
128126
129127# Core Computation
130- function (cache:: JacobianCache )(u)
131- cache. u = u
132- cache ()
133- end
134- function (cache:: JacobianCache{<:JacobianOperator} )(:: Nothing )
135- return StatefulJacobianOperator (cache. J, cache. u, cache. p)
136- end
137128(cache:: JacobianCache )(:: Nothing ) = cache. J
138-
139- # # Operator
140- function (cache:: JacobianCache{<:JacobianOperator} )()
141- return StatefulJacobianOperator (cache. J, cache. u, cache. p)
142- end
129+ (cache:: JacobianCache{<:Number} )(:: Nothing ) = cache. J
143130
144131# # Numbers
145- function (cache:: JacobianCache{<:Number} )()
132+ function (cache:: JacobianCache{<:Number} )(u )
146133 cache. stats. njacs += 1
147134
148- (; f, J, u, p) = cache
135+ (; f, J, p) = cache
149136 cache. J = if SciMLBase. has_jac (f)
150137 f. jac (u, p)
151138 elseif SciMLBase. has_vjp (f)
@@ -159,9 +146,9 @@ function (cache::JacobianCache{<:Number})()
159146end
160147
161148# # Actually Compute the Jacobian
162- function (cache:: JacobianCache )()
149+ function (cache:: JacobianCache )(u )
163150 cache. stats. njacs += 1
164- (; f, J, u, p) = cache
151+ (; f, J, p) = cache
165152 if SciMLBase. isinplace (f)
166153 if SciMLBase. has_jac (f)
167154 f. jac (J, u, p)
0 commit comments