diff --git a/lib/NonlinearSolveBase/src/jacobian.jl b/lib/NonlinearSolveBase/src/jacobian.jl index 04a61789f..b0a8106f1 100644 --- a/lib/NonlinearSolveBase/src/jacobian.jl +++ b/lib/NonlinearSolveBase/src/jacobian.jl @@ -61,7 +61,7 @@ function construct_jacobian_cache( end J = if !needs_jac - JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff) + StatefulJacobianOperator(JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff), cache.u, cache.p) else if f.jac_prototype === nothing # While this is technically wasteful, it gives out the type of the Jacobian @@ -87,7 +87,7 @@ function construct_jacobian_cache( end end - return JacobianCache(J, f, fu, u, p, stats, autodiff, di_extras) + return JacobianCache(J, f, fu, p, stats, autodiff, di_extras) end function construct_jacobian_cache( @@ -107,69 +107,62 @@ function construct_jacobian_cache( @assert !(autodiff isa AutoSparse) "`autodiff` cannot be `AutoSparse` for scalar \ nonlinear problems." di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p)) - return JacobianCache(u, f, fu, u, p, stats, autodiff, di_extras) + return JacobianCache(fu, f, fu, p, stats, autodiff, di_extras) end @concrete mutable struct JacobianCache <: AbstractJacobianCache J f <: NonlinearFunction fu - u p stats::NLStats autodiff di_extras end -function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, u0 = cache.u, kwargs...) - cache.u = u0 +function InternalAPI.reinit!(cache::JacobianCache; p = cache.p, kwargs...) cache.p = p end # Core Computation -(cache::JacobianCache)(u) = cache(cache.J, u, cache.p) -function (cache::JacobianCache{<:JacobianOperator})(::Nothing) - return StatefulJacobianOperator(cache.J, cache.u, cache.p) -end (cache::JacobianCache)(::Nothing) = cache.J - -## Operator -function (cache::JacobianCache{<:JacobianOperator})(J::JacobianOperator, u, p = cache.p) - return StatefulJacobianOperator(J, u, p) -end +(cache::JacobianCache{<:Number})(::Nothing) = cache.J ## Numbers -function (cache::JacobianCache{<:Number})(::Number, u, p = cache.p) +function (cache::JacobianCache{<:Number})(u) cache.stats.njacs += 1 - cache.J = if SciMLBase.has_jac(cache.f) - cache.f.jac(u, p) - elseif SciMLBase.has_vjp(cache.f) - cache.f.vjp(one(u), u, p) - elseif SciMLBase.has_jvp(cache.f) - cache.f.jvp(one(u), u, p) + + (; f, J, p) = cache + cache.J = if SciMLBase.has_jac(f) + f.jac(u, p) + elseif SciMLBase.has_vjp(f) + f.vjp(one(u), u, p) + elseif SciMLBase.has_jvp(f) + f.jvp(one(u), u, p) else - DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) + DI.derivative(f, cache.di_extras, cache.autodiff, u, Constant(p)) end return cache.J end ## Actually Compute the Jacobian -function (cache::JacobianCache)(J::Union{AbstractMatrix, Nothing}, u, p = cache.p) +function (cache::JacobianCache)(u) cache.stats.njacs += 1 - if SciMLBase.isinplace(cache.f) - if SciMLBase.has_jac(cache.f) - cache.f.jac(J, u, p) + (; f, J, p) = cache + if SciMLBase.isinplace(f) + if SciMLBase.has_jac(f) + f.jac(J, u, p) else DI.jacobian!( - cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p) + f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p) ) end return J else if SciMLBase.has_jac(cache.f) - cache.J = cache.f.jac(u, p) + cache.J = f.jac(u, p) else - cache.J = DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) + cache.J = DI.jacobian(f, cache.di_extras, cache.autodiff, u, Constant(p)) end return cache.J end diff --git a/lib/NonlinearSolveBase/src/tracing.jl b/lib/NonlinearSolveBase/src/tracing.jl index c7ae3f542..6caba9f4c 100644 --- a/lib/NonlinearSolveBase/src/tracing.jl +++ b/lib/NonlinearSolveBase/src/tracing.jl @@ -222,12 +222,13 @@ function update_trace!(cache, α = true; uses_jac_inverse = Val(false)) trace === missing && return nothing J = Utils.safe_getproperty(cache, Val(:J)) + du = SciMLBase.get_du(cache) if J === missing update_trace!( - trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, cache.du, α + trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, du, α ) else J = uses_jac_inverse isa Val{true} ? Utils.Pinv(cache.J) : cache.J - update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, cache.du, α) + update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, du, α) end end diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index db77d67e9..0505ceab5 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -57,7 +57,7 @@ SciMLBase = "2.69" SciMLJacobianOperators = "0.1.0" Setfield = "1.1.1" SparseArrays = "1.10" -SparseConnectivityTracer = "0.6.8" +SparseConnectivityTracer = "0.6.8, 1" SparseMatrixColorings = "0.4.5" StableRNGs = "1" StaticArrays = "1.9.8" diff --git a/lib/NonlinearSolveFirstOrder/src/solve.jl b/lib/NonlinearSolveFirstOrder/src/solve.jl index 56cbaac7b..5b3b64e0a 100644 --- a/lib/NonlinearSolveFirstOrder/src/solve.jl +++ b/lib/NonlinearSolveFirstOrder/src/solve.jl @@ -55,8 +55,6 @@ end u u_cache p - du # Aliased to `get_du(descent_cache)` - J # Aliased to `jac_cache.J` alg <: GeneralizedFirstOrderAlgorithm prob <: AbstractNonlinearProblem globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}} @@ -91,6 +89,13 @@ end initializealg end +function SciMLBase.get_du(cache::GeneralizedFirstOrderAlgorithmCache) + SciMLBase.get_du(cache.descent_cache) +end +function NonlinearSolveBase.set_du!(cache::GeneralizedFirstOrderAlgorithmCache, δu) + NonlinearSolveBase.set_du!(cache.descent_cache, δu) +end + function InternalAPI.reinit_self!( cache::GeneralizedFirstOrderAlgorithmCache, args...; p = cache.p, u0 = cache.u, alias_u0::Bool = hasproperty(cache, :alias_u0) ? cache.alias_u0 : false, @@ -212,7 +217,7 @@ function SciMLBase.__init( ) cache = GeneralizedFirstOrderAlgorithmCache( - fu, u, u_cache, prob.p, du, J, alg, prob, globalization, + fu, u, u_cache, prob.p, alg, prob, globalization, jac_cache, descent_cache, linesearch_cache, trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs, diff --git a/lib/NonlinearSolveQuasiNewton/src/solve.jl b/lib/NonlinearSolveQuasiNewton/src/solve.jl index f49c86df8..e24c6f45e 100644 --- a/lib/NonlinearSolveQuasiNewton/src/solve.jl +++ b/lib/NonlinearSolveQuasiNewton/src/solve.jl @@ -56,7 +56,6 @@ end u u_cache p - du # Aliased to `get_du(descent_cache)` J # Aliased to `initialization_cache.J` if !inverted_jac alg <: QuasiNewtonAlgorithm prob <: AbstractNonlinearProblem @@ -98,6 +97,13 @@ end initializealg end +function SciMLBase.get_du(cache::QuasiNewtonCache) + SciMLBase.get_du(cache.descent_cache) +end +function NonlinearSolveBase.set_du!(cache::QuasiNewtonCache, δu) + NonlinearSolveBase.set_du!(cache.descent_cache, δu) +end + function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache) NonlinearSolveBase.get_abstol(cache.termination_cache) end @@ -220,7 +226,7 @@ function SciMLBase.__init( ) cache = QuasiNewtonCache( - fu, u, u_cache, prob.p, du, J, alg, prob, globalization, + fu, u, u_cache, prob.p, J, alg, prob, globalization, initialization_cache, descent_cache, linesearch_cache, trustregion_cache, update_rule_cache, reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime, @@ -269,7 +275,7 @@ function InternalAPI.step!( elseif recompute_jacobian === nothing # Standard Step reinit = InternalAPI.solve!( - cache.reinit_rule_cache, cache.J, cache.fu, cache.u, cache.du + cache.reinit_rule_cache, cache.J, cache.fu, cache.u, SciMLBase.get_du(cache) ) reinit && (countable_reinit = true) elseif recompute_jacobian diff --git a/lib/NonlinearSolveSpectralMethods/src/solve.jl b/lib/NonlinearSolveSpectralMethods/src/solve.jl index b4bb23d43..9bfd5709c 100644 --- a/lib/NonlinearSolveSpectralMethods/src/solve.jl +++ b/lib/NonlinearSolveSpectralMethods/src/solve.jl @@ -72,6 +72,13 @@ end initializealg end +function SciMLBase.get_du(cache::GeneralizedDFSaneCache) + cache.du +end +function NonlinearSolveBase.set_du!(cache::GeneralizedDFSaneCache, δu) + cache.du = δu +end + function InternalAPI.reinit_self!( cache::GeneralizedDFSaneCache, args...; p = cache.p, u0 = cache.u, alias_u0::Bool = hasproperty(cache, :alias_u0) ? cache.alias_u0 : false,