Skip to content

Commit 7a17463

Browse files
Merge pull request #1548 from SciML/myb/promote
Promote to concrete type by default if it's an array
2 parents e8630bc + 4f3e003 commit 7a17463

File tree

13 files changed

+92
-50
lines changed

13 files changed

+92
-50
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ julia = "1.6"
8282
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
8383
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8484
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
85+
GalacticOptimJL = "9d3c5eb1-403b-401b-8c0f-c11105342e6b"
8586
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
8687
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8788
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -93,4 +94,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
9394
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9495

9596
[targets]
96-
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "OrdinaryDiffEq", "Optim", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
97+
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "GalacticOptimJL", "OrdinaryDiffEq", "Optim", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/structural_transformation/codegen.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,9 @@ function ODAEProblem{iip}(
521521
sys,
522522
u0map,
523523
tspan,
524-
parammap=DiffEqBase.NullParameters();
524+
parammap = DiffEqBase.NullParameters();
525525
callback = nothing,
526+
use_union = false,
526527
kwargs...
527528
) where {iip}
528529
fun, dvs = build_torn_function(sys; kwargs...)
@@ -531,8 +532,8 @@ function ODAEProblem{iip}(
531532

532533
defs = ModelingToolkit.mergedefaults(defs,parammap,ps)
533534
defs = ModelingToolkit.mergedefaults(defs,u0map,dvs)
534-
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs)
535-
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs)
535+
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
536+
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
536537

537538
has_difference = any(isdifferenceeq, equations(sys))
538539
if has_continuous_events(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
599599
simplify=false,
600600
linenumbers = true, parallel=SerialForm(),
601601
eval_expression = true,
602+
use_union = false,
602603
kwargs...)
603604
eqs = equations(sys)
604605
dvs = states(sys)
@@ -609,12 +610,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
609610
defs = mergedefaults(defs,parammap,ps)
610611
defs = mergedefaults(defs,u0map,dvs)
611612

612-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
613-
p = varmap_to_vars(parammap,ps; defaults=defs)
613+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
614+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
614615
if implicit_dae && du0map !== nothing
615616
ddvs = map(Differential(iv), dvs)
616617
defs = mergedefaults(defs,du0map, ddvs)
617-
du0 = varmap_to_vars(du0map,ddvs; defaults=defs, toterm=identity, promotetoconcrete=true)
618+
du0 = varmap_to_vars(du0map,ddvs; defaults=defs, toterm=identity, tofloat=true)
618619
else
619620
du0 = nothing
620621
ddvs = nothing

src/systems/discrete_system/discrete_system.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,19 +169,20 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
169169
parammap=DiffEqBase.NullParameters();
170170
eval_module = @__MODULE__,
171171
eval_expression = true,
172+
use_union = false,
172173
kwargs...)
173174
dvs = states(sys)
174175
ps = parameters(sys)
175176
eqs = equations(sys)
176177
eqs = linearize_eqs(sys, eqs)
177178
iv = get_iv(sys)
178-
179+
179180
defs = defaults(sys)
180181
defs = mergedefaults(defs,parammap,ps)
181182
defs = mergedefaults(defs,u0map,dvs)
182-
183-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
184-
p = varmap_to_vars(parammap,ps; defaults=defs)
183+
184+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
185+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
185186

186187
rhss = [eq.rhs for eq in eqs]
187188
u = dvs

src/systems/jumps/jumpsystem.jl

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ end
202202
"""
203203
```julia
204204
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
205-
parammap=DiffEqBase.NullParameters; kwargs...)
205+
parammap=DiffEqBase.NullParameters;
206+
use_union=false,
207+
kwargs...)
206208
```
207209
208210
Generates a blank DiscreteProblem for a pure jump JumpSystem to utilize as
@@ -219,20 +221,22 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
219221
```
220222
"""
221223
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
222-
parammap=DiffEqBase.NullParameters(); checkbounds=false, kwargs...)
223-
224+
parammap=DiffEqBase.NullParameters(); checkbounds=false,
225+
use_union=false,
226+
kwargs...)
227+
224228
dvs = states(sys)
225229
ps = parameters(sys)
226-
230+
227231
defs = defaults(sys)
228232
defs = mergedefaults(defs,parammap,ps)
229-
defs = mergedefaults(defs,u0map,dvs)
230-
231-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
232-
p = varmap_to_vars(parammap,ps; defaults=defs)
233-
233+
defs = mergedefaults(defs,u0map,dvs)
234+
235+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
236+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
237+
234238
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
235-
239+
236240
# just taken from abstractodesystem.jl for ODEFunction def
237241
obs = observed(sys)
238242
observedfun = let sys = sys, dict = Dict()
@@ -268,10 +272,15 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
268272
```
269273
"""
270274
function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
271-
parammap=DiffEqBase.NullParameters(); kwargs...)
275+
parammap=DiffEqBase.NullParameters();
276+
use_union=false,
277+
kwargs...)
278+
dvs = states(sys)
279+
ps = parameters(sys)
272280
defs = defaults(sys)
273-
u0 = varmap_to_vars(u0map, states(sys); defaults=defs)
274-
p = varmap_to_vars(parammap, parameters(sys); defaults=defs)
281+
282+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
283+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
275284
# identity function to make syms works
276285
quote
277286
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,18 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,paramm
260260
simplify=false,
261261
linenumbers = true, parallel=SerialForm(),
262262
eval_expression = true,
263+
use_union = false,
263264
kwargs...)
264265
eqs = equations(sys)
265266
dvs = states(sys)
266267
ps = parameters(sys)
267-
268+
268269
defs = defaults(sys)
269270
defs = mergedefaults(defs,parammap,ps)
270271
defs = mergedefaults(defs,u0map,dvs)
271-
272-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
273-
p = varmap_to_vars(parammap,ps; defaults=defs)
272+
273+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
274+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
274275

275276
check_eqs_u0(eqs, dvs, u0; kwargs...)
276277

src/systems/optimization/optimizationsystem.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
149149
hess = false, sparse = false,
150150
checkbounds = false,
151151
linenumbers = true, parallel=SerialForm(),
152+
use_union = false,
152153
kwargs...) where iip
153154
dvs = states(sys)
154155
ps = parameters(sys)
@@ -180,10 +181,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
180181
defs = mergedefaults(defs,parammap,ps)
181182
defs = mergedefaults(defs,u0map,dvs)
182183

183-
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
184-
p = varmap_to_vars(parammap,ps; defaults=defs)
185-
lb = varmap_to_vars(lb,dvs; check=false)
186-
ub = varmap_to_vars(ub,dvs; check=false)
184+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
185+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
186+
lb = varmap_to_vars(lb, dvs; check=false, tofloat=false, use_union)
187+
ub = varmap_to_vars(ub, dvs; check=false, tofloat=false, use_union)
187188
OptimizationProblem{iip}(_f,u0,p;lb=lb,ub=ub,kwargs...)
188189
end
189190

@@ -215,6 +216,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
215216
hess = false, sparse = false,
216217
checkbounds = false,
217218
linenumbers = false, parallel=SerialForm(),
219+
use_union = false,
218220
kwargs...) where iip
219221
dvs = states(sys)
220222
ps = parameters(sys)
@@ -239,10 +241,10 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
239241
defs = mergedefaults(defs,parammap,ps)
240242
defs = mergedefaults(defs,u0map,dvs)
241243

242-
u0 = varmap_to_vars(u0map,dvs; defaults=defs, promotetoconcrete=true)
243-
p = varmap_to_vars(parammap,ps; defaults=defs)
244-
lb = varmap_to_vars(lb,dvs)
245-
ub = varmap_to_vars(ub,dvs)
244+
u0 = varmap_to_vars(u0map, dvs; defaults=defs, tofloat=false)
245+
p = varmap_to_vars(parammap, ps; defaults=defs, tofloat=false, use_union)
246+
lb = varmap_to_vars(lb, dvs; check=false, tofloat=false, use_union)
247+
ub = varmap_to_vars(ub, dvs; check=false, tofloat=false, use_union)
246248
quote
247249
f = $f
248250
p = $p

src/utils.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -472,15 +472,31 @@ function mergedefaults(defaults, varmap, vars)
472472
end
473473
end
474474

475-
function promote_to_concrete(vs)
476-
if isempty(vs)
475+
function promote_to_concrete(vs; tofloat=true, use_union=false)
476+
if isempty(vs)
477477
return vs
478478
end
479479
T = eltype(vs)
480-
if Base.isconcretetype(T) # nothing to do
480+
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
481481
vs
482482
else
483-
C = foldl((t, elem)->promote_type(t, eltype(elem)), vs; init=typeof(first(vs)))
483+
C = typeof(first(vs))
484+
has_int = false
485+
I = Int8
486+
for v in vs
487+
E = eltype(v)
488+
C = promote_type(C, E)
489+
if E <: Integer
490+
has_int = true
491+
I = promote_type(I, E)
492+
end
493+
end
494+
if tofloat
495+
C = float(C)
496+
elseif use_union && has_int && C !== I
497+
C = Union{C, I}
498+
return copyto!(similar(vs, C), vs)
499+
end
484500
convert.(C, vs)
485501
end
486-
end
502+
end

src/variables.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Takes a list of pairs of `variables=>values` and an ordered list of variables
3232
and creates the array of values in the correct order with default values when
3333
applicable.
3434
"""
35-
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term, promotetoconcrete=false)
35+
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term, promotetoconcrete=nothing, tofloat=true, use_union=false)
3636
varlist = map(unwrap, varlist)
3737
# Edge cases where one of the arguments is effectively empty.
3838
is_incomplete_initialization = varmap isa DiffEqBase.NullParameters || varmap === nothing
@@ -58,16 +58,16 @@ function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Sym
5858
varmap
5959
end
6060

61+
promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray)
6162
if promotetoconcrete
62-
vals = promote_to_concrete(vals)
63+
vals = promote_to_concrete(vals; tofloat=tofloat, use_union=use_union)
6364
end
6465

6566
if isempty(vals)
6667
return nothing
6768
elseif container_type <: Tuple
6869
(vals...,)
6970
else
70-
vals = identity.(vals)
7171
SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(), Val(length(vals)), vals...)
7272
end
7373
end
@@ -79,7 +79,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults=Dict(), check=false, to
7979
for (p, v) in pairs(varmap)
8080
varmap[p] = fixpoint_sub(v, varmap)
8181
end
82-
82+
8383
missingvars = setdiff(varlist, keys(varmap))
8484
check && (isempty(missingvars) || throw_missingvars(missingvars))
8585

test/controlsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, GalacticOptim, Optim
1+
using ModelingToolkit, GalacticOptim, Optim, GalacticOptimJL
22

33
@variables t x(t) v(t) u(t)
44
@parameters p[1:2]

0 commit comments

Comments
 (0)