Skip to content

Commit 95b3e40

Browse files
committed
Promote to concrete type by default if it's an array
1 parent cdcf181 commit 95b3e40

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,15 +472,15 @@ function mergedefaults(defaults, varmap, vars)
472472
end
473473
end
474474

475-
function promote_to_concrete(vs)
476-
if isempty(vs)
475+
function promote_to_concrete(vs, tofloat=true)
476+
if isempty(vs)
477477
return vs
478478
end
479479
T = eltype(vs)
480480
if Base.isconcretetype(T) # nothing to do
481481
vs
482482
else
483483
C = foldl((t, elem)->promote_type(t, eltype(elem)), vs; init=typeof(first(vs)))
484-
convert.(C, vs)
484+
convert.(tofloat ? float(C) : C, vs)
485485
end
486-
end
486+
end

src/variables.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Takes a list of pairs of `variables=>values` and an ordered list of variables
3232
and creates the array of values in the correct order with default values when
3333
applicable.
3434
"""
35-
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term, promotetoconcrete=false)
35+
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term, promotetoconcrete=nothing)
3636
varlist = map(unwrap, varlist)
3737
# Edge cases where one of the arguments is effectively empty.
3838
is_incomplete_initialization = varmap isa DiffEqBase.NullParameters || varmap === nothing
@@ -58,6 +58,7 @@ function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Sym
5858
varmap
5959
end
6060

61+
promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray)
6162
if promotetoconcrete
6263
vals = promote_to_concrete(vals)
6364
end
@@ -79,7 +80,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults=Dict(), check=false, to
7980
for (p, v) in pairs(varmap)
8081
varmap[p] = fixpoint_sub(v, varmap)
8182
end
82-
83+
8384
missingvars = setdiff(varlist, keys(varmap))
8485
check && (isempty(missingvars) || throw_missingvars(missingvars))
8586

0 commit comments

Comments
 (0)