Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 2 additions & 176 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin
end

function init(
prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing,
prob::AbstractDEProblem, args...; sensealg = nothing,
u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && has_kwargs(prob) && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
Expand Down Expand Up @@ -215,41 +215,6 @@ function build_null_solution(prob::AbstractDEProblem, args...;
build_solution(prob, nothing, ts, timeseries; dense = true, retcode)
end

function build_null_solution(
prob::Union{SteadyStateProblem, NonlinearProblem},
args...;
saveat = (),
save_everystep = true,
save_on = true,
save_start = save_everystep || isempty(saveat) ||
saveat isa Number || prob.tspan[1] in saveat,
save_end = true,
kwargs...)
prob, success = hack_null_solution_init(prob)
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure
SciMLBase.build_solution(prob, nothing, Float64[], nothing; retcode)
end

function build_null_solution(
prob::NonlinearLeastSquaresProblem,
args...; abstol = 1e-6, kwargs...)
prob, success = hack_null_solution_init(prob)
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure

if isinplace(prob)
resid = isnothing(prob.f.resid_prototype) ? Float64[] : copy(prob.f.resid_prototype)
prob.f(resid, prob.u0, prob.p)
else
resid = prob.f(prob.f.resid_prototype, prob.p)
end

if success
retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure
end

SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode)
end

"""
```julia
solve(prob::AbstractDEProblem, alg::Union{AbstractDEAlgorithm,Nothing}; kwargs...)
Expand Down Expand Up @@ -562,102 +527,7 @@ function solve(prob::AbstractDEProblem, args...; sensealg = nothing,
end
end

"""
```julia
solve(prob::NonlinearProblem, alg::Union{AbstractNonlinearAlgorithm,Nothing}; kwargs...)
```

## Arguments

The only positional argument is `alg` which is optional. By default, `alg = nothing`.
If `alg = nothing`, then `solve` dispatches to the NonlinearSolve.jl automated
algorithm selection (if `using NonlinearSolve` was done, otherwise it will
error with a `MethodError`).

## Keyword Arguments

The NonlinearSolve.jl universe has a large set of common arguments available
for the `solve` function. These arguments apply to `solve` on any problem type and
are only limited by limitations of the specific implementations.

Many of the defaults depend on the algorithm or the package the algorithm derives
from. Not all of the interface is provided by every algorithm.
For more detailed information on the defaults and the available options
for specific algorithms / packages, see the manual pages for the solvers of specific
problems.

#### Error Control

* `abstol`: Absolute tolerance.
* `reltol`: Relative tolerance.

### Miscellaneous

* `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5.
* `verbose`: Toggles whether warnings are thrown when the solver exits early.
Defaults to true.

### Sensitivity Algorithms (`sensealg`)

`sensealg` is used for choosing the way the automatic differentiation is performed.
For more information, see the documentation for SciMLSensitivity:
https://docs.sciml.ai/SciMLSensitivity/stable/
"""
function solve(prob::NonlinearProblem, args...; sensealg = nothing,
u0 = nothing, p = nothing, wrap = Val(true), kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end

if haskey(prob.kwargs, :alias_u0)
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0])
elseif haskey(kwargs, :alias_u0)
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
alias_spec = NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0])
end

if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool
alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias])
elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool
alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias])
end

if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier
alias_spec = prob.kwargs[:alias]
elseif haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier
alias_spec = kwargs[:alias]
else
alias_spec = NonlinearAliasSpecifier(alias_u0 = false)
end

alias_u0 = alias_spec.alias_u0

u0 = u0 !== nothing ? u0 : prob.u0
p = p !== nothing ? p : prob.p

if wrap isa Val{true}
wrap_sol(solve_up(prob,
sensealg,
u0,
p,
args...;
alias_u0 = alias_u0,
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
kwargs...))
else
solve_up(prob,
sensealg,
u0,
p,
args...;
alias_u0 = alias_u0,
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
kwargs...)
end
end

function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
function solve_up(prob::AbstractDEProblem, sensealg, u0, p,
args...; originator = SciMLBase.ChainRulesOriginator(),
kwargs...)
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
Expand Down Expand Up @@ -685,14 +555,6 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0
end
end

function solve_call(prob::SteadyStateProblem,
alg::SciMLBase.AbstractNonlinearAlgorithm, args...;
kwargs...)
solve_call(NonlinearProblem(prob),
alg, args...;
kwargs...)
end

function solve(prob::AbstractNoiseProblem, args...; kwargs...)
__solve(prob, args...; kwargs...)
end
Expand All @@ -705,42 +567,6 @@ function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...)
get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...)
end

function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...)
oldprob = prob
prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...)
if prob !== oldprob
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
end
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, isadapt, Inf, kwargs)
u0 = promote_u0(u0, p, nothing)
remake(prob; u0 = u0, p = p)
end

function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...)
oldprob = prob
prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...)
if prob !== oldprob
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
end
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = promote_u0(u0, p, nothing)
remake(prob; u0 = u0, p = p)
end

function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...)
oldprob = prob
prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...)
if prob !== oldprob
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
end
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = promote_u0(u0, p, nothing)
remake(prob; u0 = u0, p = p)
end

function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...)
prob
end
Expand Down
6 changes: 0 additions & 6 deletions test/high_level_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,3 @@ prob2 = DiffEqBase.get_concrete_problem(prob, true)
@test prob2.tspan == (0.0, 3.0)
@test prob2.constant_lags == [1.0]

prob = SteadyStateProblem((u, p, t) -> u, [1.0, 2.0])
prob2 = DiffEqBase.get_concrete_problem(prob, true; u0 = [2.0, 3.0])
@test prob2.u0 == [2.0, 3.0]
prob3 = DiffEqBase.get_concrete_problem(prob, true; u0 = [1.0, 3.0], p = 3.0)
@test prob3.u0 == [1.0, 3.0]
@test prob3.p == 3.0
Loading