diff --git a/lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl b/lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl index 774d22c7b..84d070ee8 100644 --- a/lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl +++ b/lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl @@ -3,38 +3,71 @@ module SCCNonlinearSolve import SciMLBase import CommonSolve import SymbolicIndexingInterface +import SciMLBase: NonlinearProblem, NonlinearLeastSquaresProblem, LinearProblem + +""" + SCCAlg(; nlalg = nothing, linalg = nothing) + +Algorithm for solving Strongly Connected Component (SCC) problems containing +both nonlinear and linear subproblems. + +### Keyword Arguments + + - `nlalg`: Algorithm to use for solving NonlinearProblem components + - `linalg`: Algorithm to use for solving LinearProblem components +""" +struct SCCAlg{N, L} + nlalg::N + linalg::L +end + +SCCAlg(; nlalg = nothing, linalg = nothing) = SCCAlg(nlalg, linalg) function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem; kwargs...) - CommonSolve.solve(prob, nothing; kwargs...) + CommonSolve.solve(prob, SCCAlg(nothing, nothing); kwargs...) end -function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg; kwargs...) - numscc = length(prob.probs) - sols = [SciMLBase.build_solution( - prob, nothing, prob.u0, convert(eltype(prob.u0), NaN) * prob.u0) - for prob in prob.probs] - u = reduce(vcat, [prob.u0 for prob in prob.probs]) - resid = copy(u) - - lasti = 1 - for i in 1:numscc - prob.explictfuns![i]( - SymbolicIndexingInterface.parameter_values(prob.probs[i]), sols) - sol = SciMLBase.solve(prob.probs[i], alg; kwargs...) - _sol = SciMLBase.build_solution( - prob.probs[i], nothing, sol.u, sol.resid, retcode = sol.retcode) - sols[i] = _sol - lasti = i - if !SciMLBase.successful_retcode(_sol) - break - end +function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SciMLBase.AbstractNonlinearAlgorithm; kwargs...) + CommonSolve.solve(prob, SCCAlg(alg, nothing); kwargs...) +end + +probvec(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}) = prob.u0 +probvec(prob::LinearProblem) = prob.b + +iteratively_build_sols(alg, sols; kwargs...) = sols + +function iteratively_build_sols(alg, sols, (prob, explicitfun), args...; kwargs...) + explicitfun( + SymbolicIndexingInterface.parameter_values(prob), sols) + + _sol = if prob isa SciMLBase.LinearProblem + sol = SciMLBase.solve(prob, alg.linalg; kwargs...) + SciMLBase.build_linear_solution( + alg.linalg, sol.u, nothing, nothing, retcode = sol.retcode) + else + sol = SciMLBase.solve(prob, alg.nlalg; kwargs...) + SciMLBase.build_solution( + prob, nothing, sol.u, sol.resid, retcode = sol.retcode) end + iteratively_build_sols(alg, (sols..., _sol), args...) +end + +function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwargs...) + numscc = length(prob.probs) + sols = iteratively_build_sols( + alg, (), zip(prob.probs, prob.explicitfuns!)...; kwargs...) + # TODO: fix allocations with a lazy concatenation - u .= reduce(vcat, sols) - resid .= reduce(vcat, getproperty.(sols, :resid)) + u = reduce(vcat, sols) + resid = reduce(vcat, getproperty.(sols, :resid)) - retcode = sols[lasti].retcode + retcode = if !all(SciMLBase.successful_retcode, sols) + idx = findfirst(!SciMLBase.successful_retcode, sols) + sols[idx].retcode + else + SciMLBase.ReturnCode.Success + end SciMLBase.build_solution(prob, alg, u, resid; retcode, original = sols) end diff --git a/lib/SCCNonlinearSolve/test/core_tests.jl b/lib/SCCNonlinearSolve/test/core_tests.jl index 5bcf01fe7..06446833b 100644 --- a/lib/SCCNonlinearSolve/test/core_tests.jl +++ b/lib/SCCNonlinearSolve/test/core_tests.jl @@ -7,64 +7,71 @@ end @testitem "Manual SCC" setup=[CoreRootfindTesting] tags=[:core] begin using NonlinearSolveFirstOrder function f(du, u, p) - du[1] = cos(u[2]) - u[1] - du[2] = sin(u[1] + u[2]) + u[2] - du[3] = 2u[4] + u[3] + 1.0 - du[4] = u[5]^2 + u[4] - du[5] = u[3]^2 + u[5] - du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8] - du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8] - du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8] + du[1]=cos(u[2])-u[1] + du[2]=sin(u[1]+u[2])+u[2] + du[3]=2u[4]+u[3]+1.0 + du[4]=u[5]^2+u[4] + du[5]=u[3]^2+u[5] + du[6]=u[1]+u[2]+u[3]+u[4]+u[5]+2.0u[6]+2.5u[7]+1.5u[8] + du[7]=u[1]+u[2]+u[3]+2.0u[4]+u[5]+4.0u[6]-1.5u[7]+1.5u[8] + du[8]=u[1]+2.0u[2]+3.0u[3]+5.0u[4]+6.0u[5]+u[6]-u[7]-u[8] end - prob = NonlinearProblem(f, zeros(8)) - sol = solve(prob, NewtonRaphson()) + prob=NonlinearProblem(f, zeros(8)) + sol=solve(prob, NewtonRaphson()) - u0 = zeros(2) - p = zeros(3) + u0=zeros(2) + p=zeros(3) function f1(du, u, p) - du[1] = cos(u[2]) - u[1] - du[2] = sin(u[1] + u[2]) + u[2] + du[1]=cos(u[2])-u[1] + du[2]=sin(u[1]+u[2])+u[2] end - explicitfun1(p, sols) = nothing - prob1 = NonlinearProblem( + explicitfun1(p, sols)=nothing + prob1=NonlinearProblem( NonlinearFunction{true, SciMLBase.NoSpecialize}(f1), zeros(2), p) - sol1 = solve(prob1, NewtonRaphson()) + sol1=solve(prob1, NewtonRaphson()) function f2(du, u, p) - du[1] = 2u[2] + u[1] + 1.0 - du[2] = u[3]^2 + u[2] - du[3] = u[1]^2 + u[3] + du[1]=2u[2]+u[1]+1.0 + du[2]=u[3]^2+u[2] + du[3]=u[1]^2+u[3] end - explicitfun2(p, sols) = nothing - prob2 = NonlinearProblem( + explicitfun2(p, sols)=nothing + prob2=NonlinearProblem( NonlinearFunction{true, SciMLBase.NoSpecialize}(f2), zeros(3), p) - sol2 = solve(prob2, NewtonRaphson()) + sol2=solve(prob2, NewtonRaphson()) - function f3(du, u, p) - du[1] = p[1] + 2.0u[1] + 2.5u[2] + 1.5u[3] - du[2] = p[2] + 4.0u[1] - 1.5u[2] + 1.5u[3] - du[3] = p[3] + +u[1] - u[2] - u[3] - end - prob3 = NonlinearProblem( - NonlinearFunction{true, SciMLBase.NoSpecialize}(f3), zeros(3), p) + # Convert f3 to a LinearProblem since it's linear in u + # du = Au + b where A is the coefficient matrix and b is from parameters + A3=[2.0 2.5 1.5; 4.0 -1.5 1.5; 1.0 -1.0 -1.0] + b3=p # b will be updated by explicitfun3 + prob3=LinearProblem(A3, b3, zeros(3)) function explicitfun3(p, sols) - p[1] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3] - p[2] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3] - p[3] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] + - 6.0sols[2][3] + p[1]=-(sols[1][1]+sols[1][2]+sols[2][1]+sols[2][2]+sols[2][3]) + p[2]=-(sols[1][1]+sols[1][2]+sols[2][1]+2.0sols[2][2]+sols[2][3]) + p[3]=-(sols[1][1]+2.0sols[1][2]+3.0sols[2][1]+5.0sols[2][2]+ + 6.0sols[2][3]) end explicitfun3(p, [sol1, sol2]) - sol3 = solve(prob3, NewtonRaphson()) - manualscc = [sol1; sol2; sol3] + sol3=solve(prob3) # LinearProblem uses default linear solver + manualscc=reduce(vcat, (sol1, sol2, sol3)) - sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], + sccprob=SciMLBase.SCCNonlinearProblem((prob1, prob2, prob3), SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3])) - scc_sol = solve(sccprob, NewtonRaphson()) + + # Test with SCCAlg that handles both nonlinear and linear problems + using SCCNonlinearSolve + scc_alg=SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson(), linalg = nothing) + scc_sol=solve(sccprob, scc_alg) + @test sol ≈ manualscc ≈ scc_sol + + # Backwards compat of alg choice + scc_sol=solve(sccprob, NewtonRaphson()) @test sol ≈ manualscc ≈ scc_sol import NonlinearSolve # Required for Default - scc_sol = solve(sccprob) - @test sol ≈ manualscc ≈ scc_sol + # Test default interface + scc_sol_default=solve(sccprob) + @test sol ≈ manualscc ≈ scc_sol_default end