Skip to content

Commit 10ccf8a

Browse files
author
Brad Carman
committed
format
1 parent d20f68d commit 10ccf8a

File tree

6 files changed

+36
-41
lines changed

6 files changed

+36
-41
lines changed

src/parameters.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function find_types(array)
6666
by = let set = Dict{Any, Int}(), counter = Ref(0)
6767
x -> begin
6868
# t = typeof(x)
69-
69+
7070
get!(set, typeof(x)) do
7171
# if t == Float64
7272
# 1
@@ -79,16 +79,14 @@ function find_types(array)
7979
return by.(array)
8080
end
8181

82-
8382
function split_parameters_by_type(ps)
84-
8583
if ps === SciMLBase.NullParameters()
86-
return Float64[],[] #use Float64 to avoid Any type warning
84+
return Float64[], [] #use Float64 to avoid Any type warning
8785
else
8886
by = let set = Dict{Any, Int}(), counter = Ref(0)
89-
x -> begin
87+
x -> begin
9088
get!(set, typeof(x)) do
91-
counter[] += 1
89+
counter[] += 1
9290
end
9391
end
9492
end
@@ -103,7 +101,7 @@ function split_parameters_by_type(ps)
103101
tighten_types = x -> identity.(x)
104102
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
105103
if length(split_ps) == 1 #Tuple not needed, only 1 type
106-
return split_ps[1], split_idxs
104+
return split_ps[1], split_idxs
107105
else
108106
return (split_ps...,), split_idxs
109107
end

src/structural_transformation/codegen.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys,
528528
tspan,
529529
parammap = DiffEqBase.NullParameters();
530530
callback = nothing,
531-
use_union = false,
531+
use_union = true,
532+
tofloat = true,
532533
check = true,
533534
kwargs...) where {iip}
534535
eqs = equations(sys)
@@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys,
540541
defs = ModelingToolkit.mergedefaults(defs, parammap, ps)
541542
defs = ModelingToolkit.mergedefaults(defs, u0map, dvs)
542543
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
543-
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union,
544-
use_union)
544+
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
545545

546546
has_difference = any(isdifferenceeq, eqs)
547547
cbs = process_events(sys; callback, has_difference, kwargs...)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,12 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
153153
kwargs...)
154154
else
155155
if p isa Tuple
156-
build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states,
156+
build_function(rhss, u, p..., t; postprocess_fbody = pre,
157+
states = sol_states,
157158
kwargs...)
158159
else
159160
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
160-
kwargs...)
161+
kwargs...)
161162
end
162163
end
163164
end
@@ -332,7 +333,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
332333
analytic = nothing,
333334
kwargs...) where {iip, specialize}
334335
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
335-
expression_module = eval_module, checkbounds = checkbounds,
336+
expression_module = eval_module, checkbounds = checkbounds,
336337
kwargs...)
337338
f_oop, f_iip = eval_expression ?
338339
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
@@ -689,7 +690,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
689690
end
690691

691692
"""
692-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
693+
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
693694
694695
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
695696
"""
@@ -743,11 +744,11 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
743744
symbolic_u0)
744745

745746
# if split_parameters
746-
p, split_idxs = split_parameters_by_type(p)
747-
if p isa Tuple
748-
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
749-
ps = (ps...,) #if p is Tuple, ps should be Tuple
750-
end
747+
p, split_idxs = split_parameters_by_type(p)
748+
if p isa Tuple
749+
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
750+
ps = (ps...,) #if p is Tuple, ps should be Tuple
751+
end
751752
# end
752753

753754
if implicit_dae && du0map !== nothing
@@ -765,7 +766,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
765766
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
766767
checkbounds = checkbounds, p = p,
767768
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
768-
sparse = sparse, eval_expression = eval_expression,
769+
sparse = sparse, eval_expression = eval_expression,
769770
kwargs...)
770771
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
771772
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ end
670670
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list."))
671671
end
672672

673-
function promote_to_concrete(vs; tofloat = true, use_union = false)
673+
function promote_to_concrete(vs; tofloat = true, use_union = true)
674674
if isempty(vs)
675675
return vs
676676
end

test/odesystem.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -734,18 +734,23 @@ let
734734
u0map = [A => 1.0]
735735
pmap = (k1 => 1.0, k2 => 1)
736736
tspan = (0.0, 1.0)
737-
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false)
738-
737+
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
739738
@test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
740739

740+
prob = ODEProblem(sys, u0map, tspan, pmap)
741+
@test prob.p isa Vector{Float64}
742+
741743
pmap = [k1 => 1, k2 => 1]
742744
tspan = (0.0, 1.0)
743745
prob = ODEProblem(sys, u0map, tspan, pmap)
744746
@test eltype(prob.p) === Float64
745-
746-
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false)
747+
748+
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
747749
@test eltype(prob.p) === Int
748750

751+
prob = ODEProblem(sys, u0map, tspan, pmap)
752+
@test prob.p isa Vector{Float64}
753+
749754
# No longer supported, Tuple used instead
750755
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
751756
# tspan = (0.0, 1.0)

test/split_parameters.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ using ModelingToolkit, Test
22
using ModelingToolkitStandardLibrary.Blocks
33
using OrdinaryDiffEq
44

5-
6-
75
# ------------------------ Mixed Single Values and Vector
86

97
dt = 4e-4
@@ -14,11 +12,10 @@ x = @. time^2 + 1.0
1412
@parameters t
1513
D = Differential(t)
1614

17-
get_value(data, t, dt) = data[round(Int, t/dt+1)]
15+
get_value(data, t, dt) = data[round(Int, t / dt + 1)]
1816
@register_symbolic get_value(data, t, dt)
1917

20-
21-
function Sampled(; name, data=Float64[], dt=0.0)
18+
function Sampled(; name, data = Float64[], dt = 0.0)
2219
pars = @parameters begin
2320
data = data
2421
dt = dt
@@ -30,7 +27,7 @@ function Sampled(; name, data=Float64[], dt=0.0)
3027
end
3128

3229
eqs = [
33-
output.u ~ get_value(data, t, dt)
30+
output.u ~ get_value(data, t, dt),
3431
]
3532

3633
return ODESystem(eqs, t, vars, pars; name, systems,
@@ -54,15 +51,13 @@ prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x])
5451
sol = solve(prob, ImplicitEuler());
5552
@test sol.retcode == ReturnCode.Success
5653

57-
5854
# ------------------------ Mixed Type Converted to float (default behavior)
5955

6056
vars = @variables y(t)=1 dy(t)=0 ddy(t)=0
6157
pars = @parameters a=1.0 b=2.0 c=3
62-
eqs = [
63-
D(y) ~ dy*a
64-
D(dy) ~ ddy*b
65-
ddy ~ sin(t)*c]
58+
eqs = [D(y) ~ dy * a
59+
D(dy) ~ ddy * b
60+
ddy ~ sin(t) * c]
6661

6762
@named sys = ODESystem(eqs, t, vars, pars)
6863
sys = structural_simplify(sys)
@@ -74,14 +69,10 @@ prob = ODEProblem(sys, [], tspan, [])
7469
sol = solve(prob, ImplicitEuler());
7570
@test sol.retcode == ReturnCode.Success
7671

77-
7872
# ------------------------ Mixed Type Conserved
7973

80-
prob = ODEProblem(sys, [], tspan, []; tofloat=false)
74+
prob = ODEProblem(sys, [], tspan, []; tofloat = false)
8175

8276
@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}}
8377
sol = solve(prob, ImplicitEuler());
8478
@test sol.retcode == ReturnCode.Success
85-
86-
87-

0 commit comments

Comments
 (0)