Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -841,35 +841,35 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
has_alg_eqs(sys) &&
error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

sts = unknowns(sys)
ps = parameters(sys)
constraintsys = get_constraintsystem(sys)

if !isnothing(constraintsys)
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
end

# ODESystems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan, guesses,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k, v) in u0map]

fns = generate_function_bc(sys, u0, u0_idxs, tspan)
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
bc(sol, p, t) = bc_oop(sol, p, t)
bc(resid, u, p, t) = bc_iip(resid, u, p, t)

Expand Down
35 changes: 21 additions & 14 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ struct ODESystem <: AbstractODESystem
"""
parent::Any

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
function ODESystem(
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
torn_matching, initializesystem, initialization_eqs, schedule,
connector_type, preface, cevents,
Expand All @@ -214,7 +215,8 @@ struct ODESystem <: AbstractODESystem
u = __get_unit_type(dvs, ps, iv)
check_units(u, deqs)
end
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad, jac,
new(tag, deqs, iv, dvs, ps, tspan, var_to_name,
ctrls, observed, constraints, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, assertions, metadata,
Expand Down Expand Up @@ -300,16 +302,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end

if !isempty(systems) && !isnothing(constraintsystem)
conssystems = ConstraintsSystem[]
for sys in systems
cons = get_constraintsystem(sys)
cons !== nothing && push!(conssystems, cons)
cons !== nothing && push!(conssystems, cons)
end
@show conssystems
@set! constraintsystem.systems = conssystems
end
end

assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)

Expand Down Expand Up @@ -359,9 +361,9 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
for st in get_unknowns(constraintsystem)
iscall(st) ?
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
!in(st, allunknowns) && push!(consvars, st)
iscall(st) ?
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
!in(st, allunknowns) && push!(consvars, st)
end
for p in parameters(constraintsystem)
!in(p, new_ps) && push!(new_ps, p)
Expand Down Expand Up @@ -712,7 +714,8 @@ end
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
function process_constraint_system(
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
isempty(constraints) && return nothing

constraintsts = OrderedSet()
Expand All @@ -725,22 +728,26 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
# Validate the states.
for var in constraintsts
if !iscall(var)
occursin(iv, var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
occursin(iv, var) && (var ∈ sts ||
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
elseif length(arguments(var)) > 1
throw(ArgumentError("Too many arguments for variable $var."))
elseif length(arguments(var)) == 1
arg = only(arguments(var))
operation(var)(iv) ∈ sts ||
operation(var)(iv) ∈ sts ||
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))

isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat ||
isequal(arg, iv) || isparameter(arg) || arg isa Integer ||
arg isa AbstractFloat ||
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

isparameter(arg) && push!(constraintps, arg)
else
var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
var ∈ sts &&
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
end
end

ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
ConstraintsSystem(
constraints, collect(constraintsts), collect(constraintps); name = consname)
end
7 changes: 3 additions & 4 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,8 @@ function check_inputmap_keys(sys, u0map, pmap)
push!(badparamkeys, k)
end
end
(isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
(isempty(badvarkeys) && isempty(badparamkeys)) ||
throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
end

const BAD_KEY_MESSAGE = """
Expand All @@ -885,14 +886,12 @@ struct InvalidKeyError <: Exception
params::Any
end

function Base.showerror(io::IO, e::InvalidKeyError)
function Base.showerror(io::IO, e::InvalidKeyError)
println(io, BAD_KEY_MESSAGE)
println(io, "u0map: $(join(e.vars, ", "))")
println(io, "pmap: $(join(e.params, ", "))")
end



##############
# Legacy functions for backward compatibility
##############
Expand Down
Loading
Loading