diff --git a/src/solve.jl b/src/solve.jl index 55b4a636f..720e92f93 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -39,7 +39,7 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin end function init( - prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing, + prob::AbstractDEProblem, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && has_kwargs(prob) && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] @@ -215,41 +215,6 @@ function build_null_solution(prob::AbstractDEProblem, args...; build_solution(prob, nothing, ts, timeseries; dense = true, retcode) end -function build_null_solution( - prob::Union{SteadyStateProblem, NonlinearProblem}, - args...; - saveat = (), - save_everystep = true, - save_on = true, - save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, - save_end = true, - kwargs...) - prob, success = hack_null_solution_init(prob) - retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure - SciMLBase.build_solution(prob, nothing, Float64[], nothing; retcode) -end - -function build_null_solution( - prob::NonlinearLeastSquaresProblem, - args...; abstol = 1e-6, kwargs...) - prob, success = hack_null_solution_init(prob) - retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure - - if isinplace(prob) - resid = isnothing(prob.f.resid_prototype) ? Float64[] : copy(prob.f.resid_prototype) - prob.f(resid, prob.u0, prob.p) - else - resid = prob.f(prob.f.resid_prototype, prob.p) - end - - if success - retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure - end - - SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) -end - """ ```julia solve(prob::AbstractDEProblem, alg::Union{AbstractDEAlgorithm,Nothing}; kwargs...) @@ -562,102 +527,7 @@ function solve(prob::AbstractDEProblem, args...; sensealg = nothing, end end -""" -```julia -solve(prob::NonlinearProblem, alg::Union{AbstractNonlinearAlgorithm,Nothing}; kwargs...) -``` - -## Arguments - -The only positional argument is `alg` which is optional. By default, `alg = nothing`. -If `alg = nothing`, then `solve` dispatches to the NonlinearSolve.jl automated -algorithm selection (if `using NonlinearSolve` was done, otherwise it will -error with a `MethodError`). - -## Keyword Arguments - -The NonlinearSolve.jl universe has a large set of common arguments available -for the `solve` function. These arguments apply to `solve` on any problem type and -are only limited by limitations of the specific implementations. - -Many of the defaults depend on the algorithm or the package the algorithm derives -from. Not all of the interface is provided by every algorithm. -For more detailed information on the defaults and the available options -for specific algorithms / packages, see the manual pages for the solvers of specific -problems. - -#### Error Control - -* `abstol`: Absolute tolerance. -* `reltol`: Relative tolerance. - -### Miscellaneous - -* `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5. -* `verbose`: Toggles whether warnings are thrown when the solver exits early. - Defaults to true. - -### Sensitivity Algorithms (`sensealg`) - -`sensealg` is used for choosing the way the automatic differentiation is performed. - For more information, see the documentation for SciMLSensitivity: - https://docs.sciml.ai/SciMLSensitivity/stable/ -""" -function solve(prob::NonlinearProblem, args...; sensealg = nothing, - u0 = nothing, p = nothing, wrap = Val(true), kwargs...) - if sensealg === nothing && haskey(prob.kwargs, :sensealg) - sensealg = prob.kwargs[:sensealg] - end - - if haskey(prob.kwargs, :alias_u0) - @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." - alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0]) - elseif haskey(kwargs, :alias_u0) - @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." - alias_spec = NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0]) - end - - if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool - alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias]) - elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool - alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias]) - end - - if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier - alias_spec = prob.kwargs[:alias] - elseif haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier - alias_spec = kwargs[:alias] - else - alias_spec = NonlinearAliasSpecifier(alias_u0 = false) - end - - alias_u0 = alias_spec.alias_u0 - - u0 = u0 !== nothing ? u0 : prob.u0 - p = p !== nothing ? p : prob.p - - if wrap isa Val{true} - wrap_sol(solve_up(prob, - sensealg, - u0, - p, - args...; - alias_u0 = alias_u0, - originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), - kwargs...)) - else - solve_up(prob, - sensealg, - u0, - p, - args...; - alias_u0 = alias_u0, - originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), - kwargs...) - end -end - -function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, +function solve_up(prob::AbstractDEProblem, sensealg, u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), kwargs...) alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) @@ -685,14 +555,6 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0 end end -function solve_call(prob::SteadyStateProblem, - alg::SciMLBase.AbstractNonlinearAlgorithm, args...; - kwargs...) - solve_call(NonlinearProblem(prob), - alg, args...; - kwargs...) -end - function solve(prob::AbstractNoiseProblem, args...; kwargs...) __solve(prob, args...; kwargs...) end @@ -705,42 +567,6 @@ function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) end -function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...) - oldprob = prob - prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) - if prob !== oldprob - kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) - end - p = get_concrete_p(prob, kwargs) - u0 = get_concrete_u0(prob, isadapt, Inf, kwargs) - u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) -end - -function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) - oldprob = prob - prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) - if prob !== oldprob - kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) - end - p = get_concrete_p(prob, kwargs) - u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) -end - -function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) - oldprob = prob - prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) - if prob !== oldprob - kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) - end - p = get_concrete_p(prob, kwargs) - u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) -end - function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...) prob end diff --git a/test/high_level_solve.jl b/test/high_level_solve.jl index dd0739b97..88a8f7fc6 100644 --- a/test/high_level_solve.jl +++ b/test/high_level_solve.jl @@ -45,9 +45,3 @@ prob2 = DiffEqBase.get_concrete_problem(prob, true) @test prob2.tspan == (0.0, 3.0) @test prob2.constant_lags == [1.0] -prob = SteadyStateProblem((u, p, t) -> u, [1.0, 2.0]) -prob2 = DiffEqBase.get_concrete_problem(prob, true; u0 = [2.0, 3.0]) -@test prob2.u0 == [2.0, 3.0] -prob3 = DiffEqBase.get_concrete_problem(prob, true; u0 = [1.0, 3.0], p = 3.0) -@test prob3.u0 == [1.0, 3.0] -@test prob3.p == 3.0