Skip to content

Commit cab11eb

Browse files
feat: validate parameter type and allow dependent initial values in param init
1 parent 015ce60 commit cab11eb

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,16 @@ struct InitializationSystemMetadata
174174
end
175175

176176
function is_parameter_solvable(p, pmap, defs, guesses)
177+
p = unwrap(p)
178+
is_variable_floatingpoint(p) || return false
177179
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
178180
_val2 = get(defs, p, nothing)
179181
_val3 = get(guesses, p, nothing)
180182
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
181183
# the ODEProblem and it has a default and a guess)
182184
return ((_val1 === missing || _val2 === missing) ||
183-
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
185+
(symbolic_type(_val1) != NotSymbolic() ||
186+
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
184187
end
185188

186189
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,10 @@ end
885885

886886
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
887887
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)
888+
889+
function is_variable_floatingpoint(sym)
890+
sym = unwrap(sym)
891+
T = symtype(sym)
892+
return T == Real || T <: AbstractFloat || T <: AbstractArray{Real} ||
893+
T <: AbstractArray{<:AbstractFloat}
894+
end

test/initializationsystem.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -647,19 +647,37 @@ end
647647
prob2.ps[p] = 0.0
648648
test_parameter(prob2, p, 2.0)
649649

650-
# Should not be solved for:
650+
# Default overridden by ODEProblem, guess provided
651+
@mtkbuild sys = ODESystem(
652+
[D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0])
653+
_pmap = merge(pmap, Dict(p => q))
654+
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
655+
test_parameter(prob, p, _pmap[q])
656+
test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p)
651657

652-
# ODEProblem value with guess, no `missing`
658+
# ODEProblem dependent value with guess, no `missing`
653659
@mtkbuild sys = ODESystem([D(x) ~ x * q, D(y) ~ y * p], t; guesses = [p => 0.0])
654660
_pmap = merge(pmap, Dict(p => 3q))
655661
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
656-
@test prob.ps[p] 3.0
657-
@test prob.f.initializeprob === nothing
658-
# Default overridden by ODEProblem, guess provided
662+
test_parameter(prob, p, 3pmap[q])
663+
664+
# Should not be solved for:
665+
666+
# Override dependent default with direct value
659667
@mtkbuild sys = ODESystem(
660668
[D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0])
669+
_pmap = merge(pmap, Dict(p => 1.0))
661670
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
662-
@test prob.ps[p] 3.0
671+
@test prob.ps[p] 1.0
672+
@test prob.f.initializeprob === nothing
673+
674+
# Non-floating point
675+
@parameters r::Int s::Int
676+
@mtkbuild sys = ODESystem(
677+
[D(x) ~ s * x, D(y) ~ y * r], t; defaults = [s => 2r], guesses = [s => 1.0])
678+
prob = ODEProblem(sys, u0map, (0.0, 1.0), [r => 1])
679+
@test prob.ps[r] == 1
680+
@test prob.ps[s] == 2
663681
@test prob.f.initializeprob === nothing
664682

665683
@mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t; guesses = [p => 0.0])

0 commit comments

Comments
 (0)