Skip to content

Commit 10d95e4

Browse files
feat: validate parameter type and allow dependent initial values in param init
1 parent 0ea8894 commit 10d95e4

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,16 @@ struct InitializationSystemMetadata
188188
end
189189

190190
function is_parameter_solvable(p, pmap, defs, guesses)
191+
p = unwrap(p)
192+
is_variable_floatingpoint(p) || return false
191193
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
192194
_val2 = get(defs, p, nothing)
193195
_val3 = get(guesses, p, nothing)
194196
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
195197
# the ODEProblem and it has a default and a guess)
196198
return ((_val1 === missing || _val2 === missing) ||
197-
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
199+
(symbolic_type(_val1) != NotSymbolic() ||
200+
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
198201
end
199202

200203
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,9 @@ end
885885

886886
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
887887
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)
888+
889+
function is_variable_floatingpoint(sym)
890+
sym = unwrap(sym)
891+
T = symtype(sym)
892+
return T == Real || T <: AbstractFloat || T <: AbstractArray{Real} || T <: AbstractArray{<:AbstractFloat}
893+
end

test/initializationsystem.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -637,21 +637,39 @@ sol = solve(oprob_2nd_order_2, Rosenbrock23()) # retcode: Success
637637
prob2.ps[p] = 0.0
638638
test_parameter(prob2, p, 2.0)
639639

640-
# Should not be solved for:
640+
# Default overridden by ODEProblem, guess provided
641+
@mtkbuild sys = ODESystem(
642+
[D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0])
643+
_pmap = merge(pmap, Dict(p => q))
644+
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
645+
test_parameter(prob, p, pmap[q])
646+
test_initializesystem(sys, u0map, pmap, p, 0 ~ q - p)
641647

642-
# ODEProblem value with guess, no `missing`
648+
# ODEProblem dependent value with guess, no `missing`
643649
@mtkbuild sys = ODESystem([D(x) ~ x * q, D(y) ~ y * p], t; guesses = [p => 0.0])
644650
_pmap = merge(pmap, Dict(p => 3q))
645651
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
646-
@test prob.ps[p] 3.0
647-
@test prob.f.initializeprob === nothing
648-
# Default overridden by ODEProblem, guess provided
652+
test_parameter(prob, p, 3pmap[q])
653+
654+
# Should not be solved for:
655+
656+
# Override dependent default with direct value
649657
@mtkbuild sys = ODESystem(
650658
[D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0])
659+
_pmap = merge(pmap, Dict(p => 1.0))
651660
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
652-
@test prob.ps[p] 3.0
661+
@test prob.ps[p] 1.0
653662
@test prob.f.initializeprob === nothing
654663

664+
# Non-floating point
665+
@parameters r::Int s::Int
666+
@mtkbuild sys = ODESystem(
667+
[D(x) ~ s * x, D(y) ~ y * r], t; defaults = [s => 2r], guesses = [s => 1.0])
668+
prob = ODEProblem(sys, u0map, (0.0, 1.0), [r => 1])
669+
@test prob.ps[r] == 1
670+
@test prob.ps[s] == 2
671+
@test prob.f.initializeprob === nothing
672+
655673
@mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t; guesses = [p => 0.0])
656674
@test_throws ModelingToolkit.MissingParametersError ODEProblem(
657675
sys, [x => 1.0, y => 1.0], (0.0, 1.0))

0 commit comments

Comments
 (0)