Skip to content

Commit 24ca6e9

Browse files
Merge pull request #2973 from AayushSabharwal/as/fix-everything
fix: fix several bugs, get MTK to precompile
2 parents dc98d54 + df80490 commit 24ca6e9

File tree

10 files changed

+19
-20
lines changed

10 files changed

+19
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ SparseArrays = "1"
112112
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
113113
StaticArrays = "0.10, 0.11, 0.12, 1.0"
114114
SymbolicIndexingInterface = "0.3.28"
115-
SymbolicUtils = "3.1.2"
115+
SymbolicUtils = "3.2"
116116
Symbolics = "6"
117117
URIs = "1"
118118
UnPack = "0.1, 1.0"

src/systems/abstractsystem.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,14 +2266,13 @@ function linearization_function(sys::AbstractSystem, inputs,
22662266
end
22672267
x0 = merge(defaults_and_guesses(sys), op)
22682268
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2269-
sys_ps = MTKParameters(sys, p, x0; eval_expression, eval_module)
2269+
sys_ps = MTKParameters(sys, p, x0)
22702270
else
22712271
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
22722272
end
22732273
p[get_iv(sys)] = NaN
22742274
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
2275-
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op);
2276-
eval_expression, eval_module)
2275+
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
22772276
initsys_ps = parameters(initsys)
22782277
p_getter = build_explicit_observed_function(
22792278
sys, initsys_ps; eval_expression, eval_module)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ function get_u0_p(sys,
722722
if symbolic_u0
723723
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
724724
else
725-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
725+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
726726
end
727727
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
728728
p = p === nothing ? SciMLBase.NullParameters() : p
@@ -732,7 +732,7 @@ end
732732

733733
function get_u0(
734734
sys, u0map, parammap = nothing; symbolic_u0 = false,
735-
toterm = default_toterm, t0 = nothing)
735+
toterm = default_toterm, t0 = nothing, use_union = true)
736736
dvs = unknowns(sys)
737737
ps = parameters(sys)
738738
defs = defaults(sys)
@@ -757,7 +757,7 @@ function get_u0(
757757
u0 = varmap_to_vars(
758758
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
759759
else
760-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, toterm)
760+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
761761
end
762762
t0 !== nothing && delete!(defs, get_iv(sys))
763763
return u0, defs
@@ -836,13 +836,13 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
836836

837837
if has_index_cache(sys) && get_index_cache(sys) !== nothing
838838
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0,
839-
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t)
839+
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t, use_union)
840840
check_eqs_u0(eqs, dvs, u0; kwargs...)
841841
p = if parammap === nothing ||
842842
parammap == SciMLBase.NullParameters() && isempty(defs)
843843
nothing
844844
else
845-
MTKParameters(sys, parammap, trueinit; t0 = t, eval_expression, eval_module)
845+
MTKParameters(sys, parammap, trueinit; t0 = t)
846846
end
847847
else
848848
u0, p, defs = get_u0_p(sys,

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm
254254
end
255255
if has_index_cache(sys) && get_index_cache(sys) !== nothing
256256
u0, defs = get_u0(sys, trueu0map, parammap)
257-
p = MTKParameters(sys, parammap, trueu0map; eval_expression, eval_module)
257+
p = MTKParameters(sys, parammap, trueu0map)
258258
else
259259
u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union)
260260
end

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
352352

353353
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
354354
if has_index_cache(sys) && get_index_cache(sys) !== nothing
355-
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
355+
p = MTKParameters(sys, parammap, u0map)
356356
else
357357
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
358358
end
@@ -458,7 +458,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
458458

459459
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
460460
if has_index_cache(sys) && get_index_cache(sys) !== nothing
461-
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
461+
p = MTKParameters(sys, parammap, u0map)
462462
else
463463
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
464464
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
396396
if has_index_cache(sys) && get_index_cache(sys) !== nothing
397397
u0, defs = get_u0(sys, u0map, parammap)
398398
check_eqs_u0(eqs, dvs, u0; kwargs...)
399-
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
399+
p = MTKParameters(sys, parammap, u0map)
400400
else
401401
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
402402
check_eqs_u0(eqs, dvs, u0; kwargs...)

src/systems/optimization/optimizationsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
290290
if parammap isa MTKParameters
291291
p = parammap
292292
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
293-
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
293+
p = MTKParameters(sys, parammap, u0map)
294294
else
295295
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
296296
end
@@ -524,7 +524,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0map,
524524

525525
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
526526
if has_index_cache(sys) && get_index_cache(sys) !== nothing
527-
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
527+
p = MTKParameters(sys, parammap, u0map)
528528
else
529529
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
530530
end

src/utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
689689
if use_union
690690
C = Union{C, E}
691691
else
692-
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
693-
C = E
692+
C2 = promote_type(C, E)
693+
@assert C2==E || C2==C "`promote_to_concrete` can't make type $E uniform with $C"
694+
C = C2
694695
end
695696
end
696697

src/variables.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,11 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
208208
val = unwrap(fixpoint_sub(var, varmap; operator = Symbolics.Operator))
209209
if !isequal(val, var)
210210
values[var] = val
211-
T = promote_type(T, typeof(val))
212211
end
213212
end
214213
missingvars = setdiff(varlist, collect(keys(values)))
215214
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))
216-
return [T(values[unwrap(var)]) for var in varlist]
215+
return [values[unwrap(var)] for var in varlist]
217216
end
218217

219218
function varmap_with_toterm(varmap; toterm = Symbolics.diff2term)

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,7 @@ end
11771177
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
11781178

11791179
function x_at_1(P)
1180-
prob = ODEProblem(sys, [x => P], (0.0, 1.0), [sys.P => P])
1180+
prob = ODEProblem(sys, [x => P], (0.0, 1.0), [sys.P => P], use_union = false)
11811181
return solve(prob, Tsit5())(1.0)
11821182
end
11831183

0 commit comments

Comments
 (0)