Skip to content

Commit 3a7dc69

Browse files
committed
start removing aliasing
1 parent b37b31b commit 3a7dc69

File tree

4 files changed

+45
-27
lines changed

4 files changed

+45
-27
lines changed

lib/NonlinearSolveBase/src/jacobian.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function construct_jacobian_cache(
107107
@assert !(autodiff isa AutoSparse) "`autodiff` cannot be `AutoSparse` for scalar \
108108
nonlinear problems."
109109
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
110-
return JacobianCache(u, f, fu, u, p, stats, autodiff, di_extras)
110+
return JacobianCache(fu, f, fu, u, p, stats, autodiff, di_extras)
111111
end
112112

113113
@concrete mutable struct JacobianCache <: AbstractJacobianCache
@@ -127,49 +127,55 @@ function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, u0 = cache.u, kw
127127
end
128128

129129
# Core Computation
130-
(cache::JacobianCache)(u) = cache(cache.J, u, cache.p)
130+
function (cache::JacobianCache)(u)
131+
cache.u = u
132+
cache()
133+
end
131134
function (cache::JacobianCache{<:JacobianOperator})(::Nothing)
132135
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
133136
end
134137
(cache::JacobianCache)(::Nothing) = cache.J
135138

136139
## Operator
137-
function (cache::JacobianCache{<:JacobianOperator})(J::JacobianOperator, u, p = cache.p)
138-
return StatefulJacobianOperator(J, u, p)
140+
function (cache::JacobianCache{<:JacobianOperator})()
141+
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
139142
end
140143

141144
## Numbers
142-
function (cache::JacobianCache{<:Number})(::Number, u, p = cache.p)
145+
function (cache::JacobianCache{<:Number})()
143146
cache.stats.njacs += 1
144-
cache.J = if SciMLBase.has_jac(cache.f)
145-
cache.f.jac(u, p)
146-
elseif SciMLBase.has_vjp(cache.f)
147-
cache.f.vjp(one(u), u, p)
148-
elseif SciMLBase.has_jvp(cache.f)
149-
cache.f.jvp(one(u), u, p)
147+
148+
(; f, J, u, p) = cache
149+
cache.J = if SciMLBase.has_jac(f)
150+
f.jac(u, p)
151+
elseif SciMLBase.has_vjp(f)
152+
f.vjp(one(u), u, p)
153+
elseif SciMLBase.has_jvp(f)
154+
f.jvp(one(u), u, p)
150155
else
151-
DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
156+
DI.derivative(f, cache.di_extras, cache.autodiff, u, Constant(p))
152157
end
153158
return cache.J
154159
end
155160

156161
## Actually Compute the Jacobian
157-
function (cache::JacobianCache)(J::Union{AbstractMatrix, Nothing}, u, p = cache.p)
162+
function (cache::JacobianCache)()
158163
cache.stats.njacs += 1
159-
if SciMLBase.isinplace(cache.f)
160-
if SciMLBase.has_jac(cache.f)
161-
cache.f.jac(J, u, p)
164+
(; f, J, u, p) = cache
165+
if SciMLBase.isinplace(f)
166+
if SciMLBase.has_jac(f)
167+
f.jac(J, u, p)
162168
else
163169
DI.jacobian!(
164-
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)
170+
f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)
165171
)
166172
end
167173
return J
168174
else
169175
if SciMLBase.has_jac(cache.f)
170-
cache.J = cache.f.jac(u, p)
176+
cache.J = f.jac(u, p)
171177
else
172-
cache.J = DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
178+
cache.J = DI.jacobian(f, cache.di_extras, cache.autodiff, u, Constant(p))
173179
end
174180
return cache.J
175181
end

lib/NonlinearSolveBase/src/tracing.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,13 @@ function update_trace!(cache, α = true; uses_jac_inverse = Val(false))
222222
trace === missing && return nothing
223223

224224
J = Utils.safe_getproperty(cache, Val(:J))
225+
du = SciMLBase.get_du(cache)
225226
if J === missing
226227
update_trace!(
227-
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, cache.du, α
228+
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, du, α
228229
)
229230
else
230231
J = uses_jac_inverse isa Val{true} ? Utils.Pinv(cache.J) : cache.J
231-
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, cache.du, α)
232+
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, du, α)
232233
end
233234
end

lib/NonlinearSolveFirstOrder/src/solve.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ end
5555
u
5656
u_cache
5757
p
58-
du # Aliased to `get_du(descent_cache)`
59-
J # Aliased to `jac_cache.J`
6058
alg <: GeneralizedFirstOrderAlgorithm
6159
prob <: AbstractNonlinearProblem
6260
globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}}
@@ -91,6 +89,13 @@ end
9189
initializealg
9290
end
9391

92+
function SciMLBase.get_du(cache::GeneralizedFirstOrderAlgorithmCache)
93+
SciMLBase.get_du(cache.descent_cache)
94+
end
95+
function NonlinearSolveBase.set_du!(cache::GeneralizedFirstOrderAlgorithmCache, δu)
96+
NonlinearSolveBase.set_du!(cache.descent_cache, δu)
97+
end
98+
9499
function InternalAPI.reinit_self!(
95100
cache::GeneralizedFirstOrderAlgorithmCache, args...; p = cache.p, u0 = cache.u,
96101
alias_u0::Bool = hasproperty(cache, :alias_u0) ? cache.alias_u0 : false,
@@ -212,7 +217,7 @@ function SciMLBase.__init(
212217
)
213218

214219
cache = GeneralizedFirstOrderAlgorithmCache(
215-
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
220+
fu, u, u_cache, prob.p, alg, prob, globalization,
216221
jac_cache, descent_cache, linesearch_cache, trustregion_cache,
217222
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
218223
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs,

lib/NonlinearSolveQuasiNewton/src/solve.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ end
5656
u
5757
u_cache
5858
p
59-
du # Aliased to `get_du(descent_cache)`
6059
J # Aliased to `initialization_cache.J` if !inverted_jac
6160
alg <: QuasiNewtonAlgorithm
6261
prob <: AbstractNonlinearProblem
@@ -98,6 +97,13 @@ end
9897
initializealg
9998
end
10099

100+
function SciMLBase.get_du(cache::QuasiNewtonCache)
101+
SciMLBase.get_du(cache.descent_cache)
102+
end
103+
function NonlinearSolveBase.set_du!(cache::QuasiNewtonCache, δu)
104+
NonlinearSolveBase.set_du!(cache.descent_cache, δu)
105+
end
106+
101107
function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache)
102108
NonlinearSolveBase.get_abstol(cache.termination_cache)
103109
end
@@ -220,7 +226,7 @@ function SciMLBase.__init(
220226
)
221227

222228
cache = QuasiNewtonCache(
223-
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
229+
fu, u, u_cache, prob.p, J, alg, prob, globalization,
224230
initialization_cache, descent_cache, linesearch_cache,
225231
trustregion_cache, update_rule_cache, reinit_rule_cache,
226232
inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime,
@@ -269,7 +275,7 @@ function InternalAPI.step!(
269275
elseif recompute_jacobian === nothing
270276
# Standard Step
271277
reinit = InternalAPI.solve!(
272-
cache.reinit_rule_cache, cache.J, cache.fu, cache.u, cache.du
278+
cache.reinit_rule_cache, cache.J, cache.fu, cache.u, get_du(cache)
273279
)
274280
reinit && (countable_reinit = true)
275281
elseif recompute_jacobian

0 commit comments

Comments
 (0)