Skip to content

Commit ae4e6f7

Browse files
committed
add tests
1 parent 417b386 commit ae4e6f7

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

src/systems/problem_utils.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -687,11 +687,11 @@ function process_SciMLProblem(
687687

688688
u0map = to_varmap(u0map, dvs)
689689
symbols_to_symbolics!(sys, u0map)
690-
check_keys(sys, u0map)
691690

692691
pmap = to_varmap(pmap, ps)
693692
symbols_to_symbolics!(sys, pmap)
694-
check_keys(sys, pmap)
693+
694+
check_inputmap_keys(sys, u0map, pmap)
695695

696696
defs = add_toterms(recursive_unwrap(defaults(sys)))
697697
cmap, cs = get_cmap(sys)
@@ -783,29 +783,37 @@ end
783783

784784
# Check that the keys of a u0map or pmap are valid
785785
# (i.e. are symbolic keys, and are defined for the system.)
786-
function check_keys(sys, map)
787-
badkeys = Any[]
788-
for k in keys(map)
786+
function check_inputmap_keys(sys, u0map, pmap)
787+
badvarkeys = Any[]
788+
for k in keys(u0map)
789789
if symbolic_type(k) === NotSymbolic()
790-
push!(badkeys, k)
790+
push!(badvarkeys, k)
791791
end
792792
end
793793

794-
isempty(badkeys) || throw(BadKeyError(collect(badkeys)))
794+
badparamkeys = Any[]
795+
for k in keys(pmap)
796+
if symbolic_type(k) === NotSymbolic()
797+
push!(badparamkeys, k)
798+
end
799+
end
800+
(isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
795801
end
796802

797803
const BAD_KEY_MESSAGE = """
798-
Undefined keys found in the parameter or initial condition maps.
799-
The following keys are either invalid or not parameters/states of the system:
804+
Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned.
805+
The following keys are invalid:
800806
"""
801807

802-
struct BadKeyError <: Exception
808+
struct InvalidKeyError <: Exception
803809
vars::Any
810+
params::Any
804811
end
805812

806-
function Base.showerror(io::IO, e::BadKeyError)
813+
function Base.showerror(io::IO, e::InvalidKeyError)
807814
println(io, BAD_KEY_MESSAGE)
808-
println(io, join(e.vars, ", "))
815+
println(io, "u0map: $(join(e.vars, ", "))")
816+
println(io, "pmap: $(join(e.params, ", "))")
809817
end
810818

811819

test/odesystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,3 +1626,18 @@ end
16261626
prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...))
16271627
@test prob.u0 isa SVector
16281628
end
1629+
1630+
@testset "input map validation" begin
1631+
import ModelingToolkit: InvalidKeyError
1632+
@variables x(t) y(t) z(t)
1633+
@parameters a b c d
1634+
eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
1635+
@mtkbuild sys = ODESystem(eqs, t)
1636+
pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
1637+
u0map = [x => 1, y => 2, z => 3]
1638+
@test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
1639+
1640+
pmap = [a => 1, b => 2, c => 3, d => 4]
1641+
u0map = [x => 1, y => 2, z => 3, :0 => 3]
1642+
@test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
1643+
end

0 commit comments

Comments
 (0)