Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 23 additions & 30 deletions lib/NonlinearSolveBase/src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function construct_jacobian_cache(
end

J = if !needs_jac
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
StatefulJacobianOperator(JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff), cache.u, cache.p)
else
if f.jac_prototype === nothing
# While this is technically wasteful, it gives out the type of the Jacobian
Expand All @@ -87,7 +87,7 @@ function construct_jacobian_cache(
end
end

return JacobianCache(J, f, fu, u, p, stats, autodiff, di_extras)
return JacobianCache(J, f, fu, p, stats, autodiff, di_extras)
end

function construct_jacobian_cache(
Expand All @@ -107,69 +107,62 @@ function construct_jacobian_cache(
@assert !(autodiff isa AutoSparse) "`autodiff` cannot be `AutoSparse` for scalar \
nonlinear problems."
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
return JacobianCache(u, f, fu, u, p, stats, autodiff, di_extras)
return JacobianCache(fu, f, fu, p, stats, autodiff, di_extras)
end

@concrete mutable struct JacobianCache <: AbstractJacobianCache
J
f <: NonlinearFunction
fu
u
p
stats::NLStats
autodiff
di_extras
end

function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, u0 = cache.u, kwargs...)
cache.u = u0
function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, kwargs...)
cache.p = p
end

# Core Computation
(cache::JacobianCache)(u) = cache(cache.J, u, cache.p)
function (cache::JacobianCache{<:JacobianOperator})(::Nothing)
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
end
(cache::JacobianCache)(::Nothing) = cache.J

## Operator
function (cache::JacobianCache{<:JacobianOperator})(J::JacobianOperator, u, p = cache.p)
return StatefulJacobianOperator(J, u, p)
end
(cache::JacobianCache{<:Number})(::Nothing) = cache.J

## Numbers
function (cache::JacobianCache{<:Number})(::Number, u, p = cache.p)
function (cache::JacobianCache{<:Number})(u)
cache.stats.njacs += 1
cache.J = if SciMLBase.has_jac(cache.f)
cache.f.jac(u, p)
elseif SciMLBase.has_vjp(cache.f)
cache.f.vjp(one(u), u, p)
elseif SciMLBase.has_jvp(cache.f)
cache.f.jvp(one(u), u, p)

(; f, J, p) = cache
cache.J = if SciMLBase.has_jac(f)
f.jac(u, p)
elseif SciMLBase.has_vjp(f)
f.vjp(one(u), u, p)
elseif SciMLBase.has_jvp(f)
f.jvp(one(u), u, p)
else
DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
DI.derivative(f, cache.di_extras, cache.autodiff, u, Constant(p))
end
return cache.J
end

## Actually Compute the Jacobian
function (cache::JacobianCache)(J::Union{AbstractMatrix, Nothing}, u, p = cache.p)
function (cache::JacobianCache)(u)
cache.stats.njacs += 1
if SciMLBase.isinplace(cache.f)
if SciMLBase.has_jac(cache.f)
cache.f.jac(J, u, p)
(; f, J, p) = cache
if SciMLBase.isinplace(f)
if SciMLBase.has_jac(f)
f.jac(J, u, p)
else
DI.jacobian!(
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)
f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)
)
end
return J
else
if SciMLBase.has_jac(cache.f)
cache.J = cache.f.jac(u, p)
cache.J = f.jac(u, p)
else
cache.J = DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
cache.J = DI.jacobian(f, cache.di_extras, cache.autodiff, u, Constant(p))
end
return cache.J
end
Expand Down
5 changes: 3 additions & 2 deletions lib/NonlinearSolveBase/src/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,13 @@ function update_trace!(cache, α = true; uses_jac_inverse = Val(false))
trace === missing && return nothing

J = Utils.safe_getproperty(cache, Val(:J))
du = SciMLBase.get_du(cache)
if J === missing
update_trace!(
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, cache.du, α
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, du, α
)
else
J = uses_jac_inverse isa Val{true} ? Utils.Pinv(cache.J) : cache.J
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, cache.du, α)
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, du, α)
end
end
2 changes: 1 addition & 1 deletion lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ SciMLBase = "2.69"
SciMLJacobianOperators = "0.1.0"
Setfield = "1.1.1"
SparseArrays = "1.10"
SparseConnectivityTracer = "0.6.8"
SparseConnectivityTracer = "0.6.8, 1"
SparseMatrixColorings = "0.4.5"
StableRNGs = "1"
StaticArrays = "1.9.8"
Expand Down
11 changes: 8 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ end
u
u_cache
p
du # Aliased to `get_du(descent_cache)`
J # Aliased to `jac_cache.J`
alg <: GeneralizedFirstOrderAlgorithm
prob <: AbstractNonlinearProblem
globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}}
Expand Down Expand Up @@ -91,6 +89,13 @@ end
initializealg
end

function SciMLBase.get_du(cache::GeneralizedFirstOrderAlgorithmCache)
SciMLBase.get_du(cache.descent_cache)
end
function NonlinearSolveBase.set_du!(cache::GeneralizedFirstOrderAlgorithmCache, δu)
NonlinearSolveBase.set_du!(cache.descent_cache, δu)
end

function InternalAPI.reinit_self!(
cache::GeneralizedFirstOrderAlgorithmCache, args...; p = cache.p, u0 = cache.u,
alias_u0::Bool = hasproperty(cache, :alias_u0) ? cache.alias_u0 : false,
Expand Down Expand Up @@ -212,7 +217,7 @@ function SciMLBase.__init(
)

cache = GeneralizedFirstOrderAlgorithmCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
fu, u, u_cache, prob.p, alg, prob, globalization,
jac_cache, descent_cache, linesearch_cache, trustregion_cache,
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs,
Expand Down
12 changes: 9 additions & 3 deletions lib/NonlinearSolveQuasiNewton/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ end
u
u_cache
p
du # Aliased to `get_du(descent_cache)`
J # Aliased to `initialization_cache.J` if !inverted_jac
alg <: QuasiNewtonAlgorithm
prob <: AbstractNonlinearProblem
Expand Down Expand Up @@ -98,6 +97,13 @@ end
initializealg
end

function SciMLBase.get_du(cache::QuasiNewtonCache)
SciMLBase.get_du(cache.descent_cache)
end
function NonlinearSolveBase.set_du!(cache::QuasiNewtonCache, δu)
NonlinearSolveBase.set_du!(cache.descent_cache, δu)
end

function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_abstol(cache.termination_cache)
end
Expand Down Expand Up @@ -220,7 +226,7 @@ function SciMLBase.__init(
)

cache = QuasiNewtonCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
fu, u, u_cache, prob.p, J, alg, prob, globalization,
initialization_cache, descent_cache, linesearch_cache,
trustregion_cache, update_rule_cache, reinit_rule_cache,
inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime,
Expand Down Expand Up @@ -269,7 +275,7 @@ function InternalAPI.step!(
elseif recompute_jacobian === nothing
# Standard Step
reinit = InternalAPI.solve!(
cache.reinit_rule_cache, cache.J, cache.fu, cache.u, cache.du
cache.reinit_rule_cache, cache.J, cache.fu, cache.u, SciMLBase.get_du(cache)
)
reinit && (countable_reinit = true)
elseif recompute_jacobian
Expand Down
7 changes: 7 additions & 0 deletions lib/NonlinearSolveSpectralMethods/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ end
initializealg
end

function SciMLBase.get_du(cache::GeneralizedDFSaneCache)
cache.du
end
function NonlinearSolveBase.set_du!(cache::GeneralizedDFSaneCache, δu)
cache.du = δu
end

function InternalAPI.reinit_self!(
cache::GeneralizedDFSaneCache, args...; p = cache.p, u0 = cache.u,
alias_u0::Bool = hasproperty(cache, :alias_u0) ? cache.alias_u0 : false,
Expand Down
Loading