Skip to content

Commit 736520a

Browse files
authored
Merge pull request #2231 from SciML/myb/ps
Handle inhomogeneous parameters using a Tuple of Vectors
2 parents 5ef23af + 6acee61 commit 736520a

File tree

9 files changed

+250
-36
lines changed

9 files changed

+250
-36
lines changed

src/parameters.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,49 @@ macro parameters(xs...)
6161
xs,
6262
toparam) |> esc
6363
end
64+
65+
function find_types(array)
66+
by = let set = Dict{Any, Int}(), counter = Ref(0)
67+
x -> begin
68+
# t = typeof(x)
69+
70+
get!(set, typeof(x)) do
71+
# if t == Float64
72+
# 1
73+
# else
74+
counter[] += 1
75+
# end
76+
end
77+
end
78+
end
79+
return by.(array)
80+
end
81+
82+
function split_parameters_by_type(ps)
83+
if ps === SciMLBase.NullParameters()
84+
return Float64[], [] #use Float64 to avoid Any type warning
85+
else
86+
by = let set = Dict{Any, Int}(), counter = Ref(0)
87+
x -> begin
88+
get!(set, typeof(x)) do
89+
counter[] += 1
90+
end
91+
end
92+
end
93+
idxs = by.(ps)
94+
split_idxs = [Int[]]
95+
for (i, idx) in enumerate(idxs)
96+
if idx > length(split_idxs)
97+
push!(split_idxs, Int[])
98+
end
99+
push!(split_idxs[idx], i)
100+
end
101+
tighten_types = x -> identity.(x)
102+
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
103+
if length(split_ps) == 1 #Tuple not needed, only 1 type
104+
return split_ps[1], split_idxs
105+
else
106+
return (split_ps...,), split_idxs
107+
end
108+
end
109+
end

src/structural_transformation/codegen.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys,
528528
tspan,
529529
parammap = DiffEqBase.NullParameters();
530530
callback = nothing,
531-
use_union = false,
531+
use_union = true,
532+
tofloat = true,
532533
check = true,
533534
kwargs...) where {iip}
534535
eqs = equations(sys)
@@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys,
540541
defs = ModelingToolkit.mergedefaults(defs, parammap, ps)
541542
defs = ModelingToolkit.mergedefaults(defs, u0map, dvs)
542543
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
543-
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union,
544-
use_union)
544+
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
545545

546546
has_difference = any(isdifferenceeq, eqs)
547547
cbs = process_events(sys; callback, has_difference, kwargs...)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
152152
states = sol_states,
153153
kwargs...)
154154
else
155-
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
156-
kwargs...)
155+
if p isa Tuple
156+
build_function(rhss, u, p..., t; postprocess_fbody = pre,
157+
states = sol_states,
158+
kwargs...)
159+
else
160+
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
161+
kwargs...)
162+
end
157163
end
158164
end
159165
end
@@ -332,8 +338,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
332338
f_oop, f_iip = eval_expression ?
333339
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
334340
f_gen
335-
f(u, p, t) = f_oop(u, p, t)
336-
f(du, u, p, t) = f_iip(du, u, p, t)
341+
if p isa Tuple
342+
g(u, p, t) = f_oop(u, p..., t)
343+
g(du, u, p, t) = f_iip(du, u, p..., t)
344+
f = g
345+
else
346+
k(u, p, t) = f_oop(u, p, t)
347+
k(du, u, p, t) = f_iip(du, u, p, t)
348+
f = k
349+
end
337350

338351
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
339352
if u0 === nothing || p === nothing || t === nothing
@@ -384,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
384397

385398
obs = observed(sys)
386399
observedfun = if steady_state
387-
let sys = sys, dict = Dict()
400+
let sys = sys, dict = Dict(), ps = ps
388401
function generated_observed(obsvar, args...)
389402
obs = get!(dict, value(obsvar)) do
390403
build_explicit_observed_function(sys, obsvar)
391404
end
392405
if args === ()
393406
let obs = obs
394-
(u, p, t = Inf) -> obs(u, p, t)
407+
(u, p, t = Inf) -> if ps isa Tuple
408+
obs(u, p..., t)
409+
else
410+
obs(u, p, t)
411+
end
395412
end
396413
else
397-
length(args) == 2 ? obs(args..., Inf) : obs(args...)
414+
if ps isa Tuple
415+
if length(args) == 2
416+
u, p = args
417+
obs(u, p..., Inf)
418+
else
419+
u, p, t = args
420+
obs(u, p..., t)
421+
end
422+
else
423+
if length(args) == 2
424+
u, p = args
425+
obs(u, p, Inf)
426+
else
427+
u, p, t = args
428+
obs(u, p, t)
429+
end
430+
end
398431
end
399432
end
400433
end
401434
else
402-
let sys = sys, dict = Dict()
435+
let sys = sys, dict = Dict(), ps = ps
403436
function generated_observed(obsvar, args...)
404437
obs = get!(dict, value(obsvar)) do
405-
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
438+
build_explicit_observed_function(sys,
439+
obsvar;
440+
checkbounds = checkbounds,
441+
ps)
406442
end
407443
if args === ()
408444
let obs = obs
409-
(u, p, t) -> obs(u, p, t)
445+
(u, p, t) -> if ps isa Tuple
446+
obs(u, p..., t)
447+
else
448+
obs(u, p, t)
449+
end
410450
end
411451
else
412-
obs(args...)
452+
if ps isa Tuple # split parameters
453+
u, p, t = args
454+
obs(u, p..., t)
455+
else
456+
obs(args...)
457+
end
413458
end
414459
end
415460
end
@@ -677,15 +722,15 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
677722
end
678723

679724
"""
680-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
725+
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
681726
682727
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
683728
"""
684729
function get_u0_p(sys,
685730
u0map,
686731
parammap;
687-
use_union = false,
688-
tofloat = !use_union,
732+
use_union = true,
733+
tofloat = true,
689734
symbolic_u0 = false)
690735
dvs = states(sys)
691736
ps = parameters(sys)
@@ -712,16 +757,27 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
712757
simplify = false,
713758
linenumbers = true, parallel = SerialForm(),
714759
eval_expression = true,
715-
use_union = false,
716-
tofloat = !use_union,
760+
use_union = true,
761+
tofloat = true,
717762
symbolic_u0 = false,
718763
kwargs...)
719764
eqs = equations(sys)
720765
dvs = states(sys)
721766
ps = parameters(sys)
722767
iv = get_iv(sys)
723768

724-
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
769+
u0, p, defs = get_u0_p(sys,
770+
u0map,
771+
parammap;
772+
tofloat,
773+
use_union,
774+
symbolic_u0)
775+
776+
p, split_idxs = split_parameters_by_type(p)
777+
if p isa Tuple
778+
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
779+
ps = (ps...,) #if p is Tuple, ps should be Tuple
780+
end
725781

726782
if implicit_dae && du0map !== nothing
727783
ddvs = map(Differential(iv), dvs)
@@ -738,7 +794,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
738794
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
739795
checkbounds = checkbounds, p = p,
740796
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
741-
sparse = sparse, eval_expression = eval_expression, kwargs...)
797+
sparse = sparse, eval_expression = eval_expression,
798+
kwargs...)
742799
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
743800
end
744801

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts;
314314
output_type = Array,
315315
checkbounds = true,
316316
drop_expr = drop_expr,
317+
ps = parameters(sys),
317318
throw = true)
318319
if (isscalar = !(ts isa AbstractVector))
319320
ts = [ts]
@@ -385,17 +386,20 @@ function build_explicit_observed_function(sys, ts;
385386
push!(obsexprs, lhs rhs)
386387
end
387388

388-
pars = parameters(sys)
389389
if inputs !== nothing
390-
pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
390+
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
391+
end
392+
if ps isa Tuple
393+
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
394+
else
395+
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
391396
end
392-
ps = DestructuredArgs(pars, inbounds = !checkbounds)
393397
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
394398
if inputs === nothing
395-
args = [dvs, ps, ivs...]
399+
args = [dvs, ps..., ivs...]
396400
else
397401
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
398-
args = [dvs, ipts, ps, ivs...]
402+
args = [dvs, ipts, ps..., ivs...]
399403
end
400404
pre = get_postprocess_fbody(sys)
401405

src/utils.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ end
219219

220220
hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
221221
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
222+
function getdefaulttype(v)
223+
def = value(getmetadata(unwrap(v), Symbolics.VariableDefaultValue, nothing))
224+
def === nothing ? Float64 : typeof(def)
225+
end
222226
function setdefault(v, val)
223227
val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
224228
end
@@ -642,10 +646,15 @@ end
642646
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list."))
643647
end
644648

645-
function promote_to_concrete(vs; tofloat = true, use_union = false)
649+
function promote_to_concrete(vs; tofloat = true, use_union = true)
646650
if isempty(vs)
647651
return vs
648652
end
653+
if vs isa Tuple #special rule, if vs is a Tuple, preserve types, container converted to Array
654+
tofloat = false
655+
use_union = true
656+
vs = Any[vs...]
657+
end
649658
T = eltype(vs)
650659
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
651660
vs
@@ -656,6 +665,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
656665
I = Int8
657666
has_int = false
658667
has_array = false
668+
has_bool = false
659669
array_T = nothing
660670
for v in vs
661671
if v isa AbstractArray
@@ -668,6 +678,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
668678
has_int = true
669679
I = promote_type(I, E)
670680
end
681+
if E <: Bool
682+
has_bool = true
683+
end
671684
end
672685
if tofloat && !has_array
673686
C = float(C)
@@ -678,6 +691,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
678691
if has_int
679692
C = Union{C, I}
680693
end
694+
if has_bool
695+
C = Union{C, Bool}
696+
end
681697
return copyto!(similar(vs, C), vs)
682698
end
683699
convert.(C, vs)

src/variables.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ applicable.
5858
"""
5959
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
6060
toterm = default_toterm, promotetoconcrete = nothing,
61-
tofloat = true, use_union = false)
61+
tofloat = true, use_union = true)
6262
varlist = collect(map(unwrap, varlist))
6363

6464
# Edge cases where one of the arguments is effectively empty.
@@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
7575
end
7676
end
7777

78-
T = typeof(varmap)
79-
# We respect the input type
80-
container_type = T <: Dict ? Array : T
78+
# T = typeof(varmap)
79+
# We respect the input type (feature removed, not needed with Tuple support)
80+
# container_type = T <: Union{Dict,Tuple} ? Array : T
81+
container_type = Array
8182

8283
vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
8384
varmap = todict(varmap)

test/odesystem.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -734,18 +734,28 @@ let
734734
u0map = [A => 1.0]
735735
pmap = (k1 => 1.0, k2 => 1)
736736
tspan = (0.0, 1.0)
737+
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
738+
@test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
739+
737740
prob = ODEProblem(sys, u0map, tspan, pmap)
738-
@test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
741+
@test prob.p isa Vector{Float64}
739742

740743
pmap = [k1 => 1, k2 => 1]
741744
tspan = (0.0, 1.0)
742745
prob = ODEProblem(sys, u0map, tspan, pmap)
743746
@test eltype(prob.p) === Float64
744747

745-
pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
746-
tspan = (0.0, 1.0)
747-
prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
748-
@test eltype(prob.p) === Union{Float64, Int}
748+
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
749+
@test eltype(prob.p) === Int
750+
751+
prob = ODEProblem(sys, u0map, tspan, pmap)
752+
@test prob.p isa Vector{Float64}
753+
754+
# No longer supported, Tuple used instead
755+
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
756+
# tspan = (0.0, 1.0)
757+
# prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
758+
# @test eltype(prob.p) === Union{Float64, Int}
749759
end
750760

751761
let

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using SafeTestsets, Test
2323
@safetestset "JumpSystem Test" include("jumpsystem.jl")
2424
@safetestset "Constraints Test" include("constraints.jl")
2525
@safetestset "Reduction Test" include("reduction.jl")
26+
@safetestset "Split Parameters Test" include("split_parameters.jl")
2627
@safetestset "ODAEProblem Test" include("odaeproblem.jl")
2728
@safetestset "Components Test" include("components.jl")
2829
@safetestset "Model Parsing Test" include("model_parsing.jl")

0 commit comments

Comments
 (0)