@@ -61,7 +61,7 @@ function construct_jacobian_cache(
61
61
end
62
62
63
63
J = if ! needs_jac
64
- JacobianOperator (prob, fu, u; jvp_autodiff, vjp_autodiff)
64
+ StatefulJacobianOperator ( JacobianOperator (prob, fu, u; jvp_autodiff, vjp_autodiff), cache . u, cache . p )
65
65
else
66
66
if f. jac_prototype === nothing
67
67
# While this is technically wasteful, it gives out the type of the Jacobian
@@ -87,7 +87,7 @@ function construct_jacobian_cache(
87
87
end
88
88
end
89
89
90
- return JacobianCache (J, f, fu, u, p, stats, autodiff, di_extras)
90
+ return JacobianCache (J, f, fu, p, stats, autodiff, di_extras)
91
91
end
92
92
93
93
function construct_jacobian_cache (
@@ -107,45 +107,32 @@ function construct_jacobian_cache(
107
107
@assert ! (autodiff isa AutoSparse) " `autodiff` cannot be `AutoSparse` for scalar \
108
108
nonlinear problems."
109
109
di_extras = DI. prepare_derivative (f, autodiff, u, Constant (prob. p))
110
- return JacobianCache (fu, f, fu, u, p, stats, autodiff, di_extras)
110
+ return JacobianCache (fu, f, fu, p, stats, autodiff, di_extras)
111
111
end
112
112
113
113
@concrete mutable struct JacobianCache <: AbstractJacobianCache
114
114
J
115
115
f <: NonlinearFunction
116
116
fu
117
- u
118
117
p
119
118
stats:: NLStats
120
119
autodiff
121
120
di_extras
122
121
end
123
122
124
- function InternalAPI. reinit! (cache:: JacobianCache ; p = cache. p, u0 = cache. u, kwargs... )
125
- cache. u = u0
123
+ function InternalAPI. reinit! (cache:: JacobianCache ; p = cache. p, kwargs... )
126
124
cache. p = p
127
125
end
128
126
129
127
# Core Computation
130
- function (cache:: JacobianCache )(u)
131
- cache. u = u
132
- cache ()
133
- end
134
- function (cache:: JacobianCache{<:JacobianOperator} )(:: Nothing )
135
- return StatefulJacobianOperator (cache. J, cache. u, cache. p)
136
- end
137
128
(cache:: JacobianCache )(:: Nothing ) = cache. J
138
-
139
- # # Operator
140
- function (cache:: JacobianCache{<:JacobianOperator} )()
141
- return StatefulJacobianOperator (cache. J, cache. u, cache. p)
142
- end
129
+ (cache:: JacobianCache{<:Number} )(:: Nothing ) = cache. J
143
130
144
131
# # Numbers
145
- function (cache:: JacobianCache{<:Number} )()
132
+ function (cache:: JacobianCache{<:Number} )(u )
146
133
cache. stats. njacs += 1
147
134
148
- (; f, J, u, p) = cache
135
+ (; f, J, p) = cache
149
136
cache. J = if SciMLBase. has_jac (f)
150
137
f. jac (u, p)
151
138
elseif SciMLBase. has_vjp (f)
@@ -159,9 +146,9 @@ function (cache::JacobianCache{<:Number})()
159
146
end
160
147
161
148
# # Actually Compute the Jacobian
162
- function (cache:: JacobianCache )()
149
+ function (cache:: JacobianCache )(u )
163
150
cache. stats. njacs += 1
164
- (; f, J, u, p) = cache
151
+ (; f, J, p) = cache
165
152
if SciMLBase. isinplace (f)
166
153
if SciMLBase. has_jac (f)
167
154
f. jac (J, u, p)
0 commit comments