Skip to content

Commit ae3ddda

Browse files
Merge pull request #2575 from AayushSabharwal/as/initialization-fix
fix: fix initialization and observed bugs
2 parents 4db0053 + 570cf35 commit ae3ddda

File tree

5 files changed

+50
-6
lines changed

5 files changed

+50
-6
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,6 @@ function build_explicit_observed_function(sys, ts;
460460
else
461461
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
462462
end
463-
if isempty(ps)
464-
ps = (DestructuredArgs([]),)
465-
end
466463
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
467464
if inputs === nothing
468465
args = [dvs, ps..., ivs...]

src/systems/parameter_buffer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function MTKParameters(
3939
u0 = Dict()
4040
end
4141
defs = merge(defs, u0)
42-
defs = merge(defs, Dict(eq.lhs => eq.rhs for eq in observed(sys)))
42+
defs = merge(Dict(eq.lhs => eq.rhs for eq in observed(sys)), defs)
4343
p = merge(defs, p)
4444
p = merge(Dict(unwrap(k) => v for (k, v) in p),
4545
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))

src/variables.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,11 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
192192
toterm = Symbolics.diff2term, initialization_phase = false)
193193
varmap = canonicalize_varmap(varmap; toterm)
194194
defaults = canonicalize_varmap(defaults; toterm)
195+
varmap = merge(defaults, varmap)
195196
values = Dict()
196197
for var in varlist
197198
var = unwrap(var)
198-
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap; operator = Symbolics.Operator),
199-
defaults; operator = Symbolics.Operator))
199+
val = unwrap(fixpoint_sub(var, varmap; operator = Symbolics.Operator))
200200
if symbolic_type(val) === NotSymbolic()
201201
values[var] = val
202202
end
@@ -211,6 +211,11 @@ function canonicalize_varmap(varmap; toterm = Symbolics.diff2term)
211211
for (k, v) in varmap
212212
new_varmap[unwrap(k)] = unwrap(v)
213213
new_varmap[toterm(unwrap(k))] = unwrap(v)
214+
if Symbolics.isarraysymbolic(k) && Symbolics.shape(k) !== Symbolics.Unknown()
215+
for i in eachindex(k)
216+
new_varmap[k[i]] = v[i]
217+
end
218+
end
214219
end
215220
return new_varmap
216221
end

test/initial_values.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,30 @@ getter = getu(sys, [x..., y, z...])
3333
@test getter(get_u0(
3434
sys, [y => 2w, w => 3.0, z[1] => 2p1, z[2] => 3p2], [p1 => 3.0, p2 => 4.0])[1]) ==
3535
[1.0, 2.0, 3.0, 6.0, 6.0, 12.0]
36+
37+
# Issue#2566
38+
@variables X(t)
39+
@parameters p1 p2 p3
40+
41+
p_vals = [p1 => 1.0, p2 => 2.0]
42+
u_vals = [X => 3.0]
43+
44+
var_vals = [p1 => 1.0, p2 => 2.0, X => 3.0]
45+
desired_values = [p1, p2, p3]
46+
defaults = Dict([p3 => X])
47+
vals = ModelingToolkit.varmap_to_vars(var_vals, desired_values; defaults = defaults)
48+
@test vals == [1.0, 2.0, 3.0]
49+
50+
# Issue#2565
51+
# Create ODESystem.
52+
@variables X1(t) X2(t)
53+
@parameters k1 k2 Γ[1:1]=X1 + X2
54+
eq = D(X1) ~ -k1 * X1 + k2 * (-X1 + Γ[1])
55+
obs = X2 ~ Γ[1] - X1
56+
@mtkbuild osys_m = ODESystem([eq], t, [X1], [k1, k2, Γ[1]]; observed = [X2 ~ Γ[1] - X1])
57+
58+
# Creates ODEProblem.
59+
u0 = [X1 => 1.0, X2 => 2.0]
60+
tspan = (0.0, 1.0)
61+
ps = [k1 => 1.0, k2 => 5.0]
62+
@test_nowarn oprob = ODEProblem(osys_m, u0, tspan, ps)

test/nonlinearsystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,18 @@ testdict = Dict([:test => 1])
242242
@test prob_.u0 == [1.0, 2.0, 1.0]
243243
@test prob_.p == MTKParameters(sys, [a => 2.0, b => 1.0, c => 1.0])
244244
end
245+
246+
@testset "Observed function generation without parameters" begin
247+
@variables x y z
248+
249+
eqs = [0 ~ x + sin(y),
250+
0 ~ z - cos(x),
251+
0 ~ x * y]
252+
@named ns = NonlinearSystem(eqs, [x, y, z], [])
253+
ns = complete(ns)
254+
vs = [unknowns(ns); parameters(ns)]
255+
ss_mtk = structural_simplify(ns)
256+
prob = NonlinearProblem(ss_mtk, vs .=> 1.0)
257+
sol = solve(prob)
258+
@test_nowarn sol[unknowns(ns)]
259+
end

0 commit comments

Comments
 (0)