Skip to content

Commit b92b9ee

Browse files
committed
Add use_union option
1 parent eb06581 commit b92b9ee

File tree

8 files changed

+68
-42
lines changed

8 files changed

+68
-42
lines changed

src/structural_transformation/codegen.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,9 @@ function ODAEProblem{iip}(
521521
sys,
522522
u0map,
523523
tspan,
524-
parammap=DiffEqBase.NullParameters();
524+
parammap = DiffEqBase.NullParameters();
525525
callback = nothing,
526+
use_union = false,
526527
kwargs...
527528
) where {iip}
528529
fun, dvs = build_torn_function(sys; kwargs...)
@@ -531,8 +532,8 @@ function ODAEProblem{iip}(
531532

532533
defs = ModelingToolkit.mergedefaults(defs,parammap,ps)
533534
defs = ModelingToolkit.mergedefaults(defs,u0map,dvs)
534-
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs)
535-
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs)
535+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
536+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
536537

537538
has_difference = any(isdifferenceeq, equations(sys))
538539
if has_continuous_events(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
599599
simplify=false,
600600
linenumbers = true, parallel=SerialForm(),
601601
eval_expression = true,
602+
use_union = false,
602603
kwargs...)
603604
eqs = equations(sys)
604605
dvs = states(sys)
@@ -609,12 +610,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
609610
defs = mergedefaults(defs,parammap,ps)
610611
defs = mergedefaults(defs,u0map,dvs)
611612

612-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
613-
p = varmap_to_vars(parammap,ps; defaults=defs)
613+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
614+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
614615
if implicit_dae && du0map !== nothing
615616
ddvs = map(Differential(iv), dvs)
616617
defs = mergedefaults(defs,du0map, ddvs)
617-
du0 = varmap_to_vars(du0map,ddvs; defaults=defs, toterm=identity, promotetoconcrete=true)
618+
du0 = varmap_to_vars(du0map,ddvs; defaults=defs, toterm=identity, tofloat=true)
618619
else
619620
du0 = nothing
620621
ddvs = nothing

src/systems/discrete_system/discrete_system.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,19 +169,20 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
169169
parammap=DiffEqBase.NullParameters();
170170
eval_module = @__MODULE__,
171171
eval_expression = true,
172+
use_union = false,
172173
kwargs...)
173174
dvs = states(sys)
174175
ps = parameters(sys)
175176
eqs = equations(sys)
176177
eqs = linearize_eqs(sys, eqs)
177178
iv = get_iv(sys)
178-
179+
179180
defs = defaults(sys)
180181
defs = mergedefaults(defs,parammap,ps)
181182
defs = mergedefaults(defs,u0map,dvs)
182-
183-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
184-
p = varmap_to_vars(parammap,ps; defaults=defs)
183+
184+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
185+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
185186

186187
rhss = [eq.rhs for eq in eqs]
187188
u = dvs

src/systems/jumps/jumpsystem.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ end
202202
"""
203203
```julia
204204
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
205-
parammap=DiffEqBase.NullParameters; kwargs...)
205+
parammap=DiffEqBase.NullParameters;
206+
use_union=false,
207+
kwargs...)
206208
```
207209
208210
Generates a blank DiscreteProblem for a pure jump JumpSystem to utilize as
@@ -219,20 +221,22 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
219221
```
220222
"""
221223
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
222-
parammap=DiffEqBase.NullParameters(); checkbounds=false, kwargs...)
223-
224+
parammap=DiffEqBase.NullParameters(); checkbounds=false,
225+
use_union=false,
226+
kwargs...)
227+
224228
dvs = states(sys)
225229
ps = parameters(sys)
226-
230+
227231
defs = defaults(sys)
228232
defs = mergedefaults(defs,parammap,ps)
229-
defs = mergedefaults(defs,u0map,dvs)
230-
231-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
232-
p = varmap_to_vars(parammap,ps; defaults=defs)
233-
233+
defs = mergedefaults(defs,u0map,dvs)
234+
235+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
236+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
237+
234238
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
235-
239+
236240
# just taken from abstractodesystem.jl for ODEFunction def
237241
obs = observed(sys)
238242
observedfun = let sys = sys, dict = Dict()
@@ -268,10 +272,12 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
268272
```
269273
"""
270274
function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
271-
parammap=DiffEqBase.NullParameters(); kwargs...)
275+
parammap=DiffEqBase.NullParameters();
276+
use_union=false,
277+
kwargs...)
272278
defs = defaults(sys)
273-
u0 = varmap_to_vars(u0map, states(sys); defaults=defs)
274-
p = varmap_to_vars(parammap, parameters(sys); defaults=defs)
279+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
280+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
275281
# identity function to make syms works
276282
quote
277283
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,18 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,paramm
260260
simplify=false,
261261
linenumbers = true, parallel=SerialForm(),
262262
eval_expression = true,
263+
use_union = false,
263264
kwargs...)
264265
eqs = equations(sys)
265266
dvs = states(sys)
266267
ps = parameters(sys)
267-
268+
268269
defs = defaults(sys)
269270
defs = mergedefaults(defs,parammap,ps)
270271
defs = mergedefaults(defs,u0map,dvs)
271-
272-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
273-
p = varmap_to_vars(parammap,ps; defaults=defs)
272+
273+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
274+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
274275

275276
check_eqs_u0(eqs, dvs, u0; kwargs...)
276277

src/systems/optimization/optimizationsystem.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
180180
defs = mergedefaults(defs,parammap,ps)
181181
defs = mergedefaults(defs,u0map,dvs)
182182

183-
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
184-
p = varmap_to_vars(parammap,ps; defaults=defs)
185-
lb = varmap_to_vars(lb,dvs; check=false)
186-
ub = varmap_to_vars(ub,dvs; check=false)
183+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
184+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
185+
lb = varmap_to_vars(lb, dvs; check=false, tofloat=false, use_union)
186+
ub = varmap_to_vars(ub, dvs; check=false, tofloat=false, use_union)
187187
OptimizationProblem{iip}(_f,u0,p;lb=lb,ub=ub,kwargs...)
188188
end
189189

@@ -215,6 +215,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
215215
hess = false, sparse = false,
216216
checkbounds = false,
217217
linenumbers = false, parallel=SerialForm(),
218+
use_union = false,
218219
kwargs...) where iip
219220
dvs = states(sys)
220221
ps = parameters(sys)
@@ -239,10 +240,10 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
239240
defs = mergedefaults(defs,parammap,ps)
240241
defs = mergedefaults(defs,u0map,dvs)
241242

242-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
243-
p = varmap_to_vars(parammap,ps; defaults=defs)
244-
lb = varmap_to_vars(lb,dvs)
245-
ub = varmap_to_vars(ub,dvs)
243+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
244+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
245+
lb = varmap_to_vars(lb, dvs; check=false, tofloat=false, use_union)
246+
ub = varmap_to_vars(ub, dvs; check=false, tofloat=false, use_union)
246247
quote
247248
f = $f
248249
p = $p

src/utils.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,15 +472,31 @@ function mergedefaults(defaults, varmap, vars)
472472
end
473473
end
474474

475-
function promote_to_concrete(vs, tofloat=true)
475+
function promote_to_concrete(vs; tofloat=true, use_union=false)
476476
if isempty(vs)
477477
return vs
478478
end
479479
T = eltype(vs)
480-
if Base.isconcretetype(T) # nothing to do
480+
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
481481
vs
482482
else
483-
C = foldl((t, elem)->promote_type(t, eltype(elem)), vs; init=typeof(first(vs)))
484-
convert.(tofloat ? float(C) : C, vs)
483+
C = typeof(first(vs))
484+
has_int = false
485+
I = Int8
486+
for v in vs
487+
E = eltype(v)
488+
C = promote_type(C, E)
489+
if E <: Integer
490+
has_int = true
491+
I = promote_type(I, E)
492+
end
493+
end
494+
if tofloat
495+
C = float(C)
496+
elseif use_union && has_int && C !== I
497+
C = Union{C, I}
498+
return copyto!(similar(vs, C), vs)
499+
end
500+
convert.(C, vs)
485501
end
486502
end

src/variables.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Takes a list of pairs of `variables=>values` and an ordered list of variables
3232
and creates the array of values in the correct order with default values when
3333
applicable.
3434
"""
35-
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term, promotetoconcrete=nothing)
35+
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term, promotetoconcrete=nothing, tofloat=true, use_union=false)
3636
varlist = map(unwrap, varlist)
3737
# Edge cases where one of the arguments is effectively empty.
3838
is_incomplete_initialization = varmap isa DiffEqBase.NullParameters || varmap === nothing
@@ -60,15 +60,14 @@ function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Sym
6060

6161
promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray)
6262
if promotetoconcrete
63-
vals = promote_to_concrete(vals)
63+
vals = promote_to_concrete(vals; tofloat=tofloat, use_union=use_union)
6464
end
6565

6666
if isempty(vals)
6767
return nothing
6868
elseif container_type <: Tuple
6969
(vals...,)
7070
else
71-
vals = identity.(vals)
7271
SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(), Val(length(vals)), vals...)
7372
end
7473
end

0 commit comments

Comments
 (0)