Skip to content

Commit 3b8d2d6

Browse files
Merge pull request #2785 from AayushSabharwal/obsfixedpoint
fix: avoid infinite loops in MTKParameters initialization
2 parents 72044e4 + 9b4dd00 commit 3b8d2d6

File tree

1 file changed

+48
-22
lines changed

1 file changed

+48
-22
lines changed

src/systems/parameter_buffer.jl

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,35 +43,62 @@ function MTKParameters(
4343
end
4444
defs = merge(defs, u0)
4545
defs = merge(Dict(eq.lhs => eq.rhs for eq in observed(sys)), defs)
46-
p = merge(defs, p)
47-
p = merge(Dict(unwrap(k) => v for (k, v) in p),
48-
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
49-
p = Dict(unwrap(k) => fixpoint_sub(v, p) for (k, v) in p)
50-
for (sym, _) in p
51-
if iscall(sym) && operation(sym) === getindex &&
52-
first(arguments(sym)) in all_ps
53-
error("Scalarized parameter values ($sym) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
46+
bigdefs = merge(defs, p)
47+
p = Dict()
48+
missing_params = Set()
49+
pdeps = has_parameter_dependencies(sys) ? parameter_dependencies(sys) : nothing
50+
51+
for sym in all_ps
52+
ttsym = default_toterm(sym)
53+
isarr = iscall(sym) && operation(sym) === getindex
54+
arrparent = isarr ? arguments(sym)[1] : nothing
55+
ttarrparent = isarr ? default_toterm(arrparent) : nothing
56+
pname = hasname(sym) ? getname(sym) : nothing
57+
ttpname = hasname(ttsym) ? getname(ttsym) : nothing
58+
p[sym] = p[ttsym] = if haskey(bigdefs, sym)
59+
bigdefs[sym]
60+
elseif haskey(bigdefs, ttsym)
61+
bigdefs[ttsym]
62+
elseif haskey(bigdefs, pname)
63+
isarr ? bigdefs[pname][arguments(sym)[2:end]...] : bigdefs[pname]
64+
elseif haskey(bigdefs, ttpname)
65+
isarr ? bigdefs[ttpname][arguments(sym)[2:end]...] : bigdefs[pname]
66+
elseif isarr && haskey(bigdefs, arrparent)
67+
bigdefs[arrparent][arguments(sym)[2:end]...]
68+
elseif isarr && haskey(bigdefs, ttarrparent)
69+
bigdefs[ttarrparent][arguments(sym)[2:end]...]
5470
end
71+
if get(p, sym, nothing) === nothing
72+
push!(missing_params, sym)
73+
continue
74+
end
75+
# We may encounter the `ttsym` version first, add it to `missing_params`
76+
# then encounter the "normal" version of a parameter or vice versa
77+
# Remove the old one in `missing_params` just in case
78+
delete!(missing_params, sym)
79+
delete!(missing_params, ttsym)
5580
end
5681

57-
missing_params = Set()
58-
for idxmap in (ic.tunable_idx, ic.discrete_idx, ic.constant_idx, ic.nonnumeric_idx)
59-
for sym in keys(idxmap)
60-
sym isa Symbol && continue
61-
haskey(p, sym) && continue
62-
hasname(sym) && haskey(p, getname(sym)) && continue
82+
if pdeps !== nothing
83+
for (sym, expr) in pdeps
84+
sym = unwrap(sym)
6385
ttsym = default_toterm(sym)
64-
haskey(p, ttsym) && continue
65-
hasname(ttsym) && haskey(p, getname(ttsym)) && continue
66-
67-
iscall(sym) && operation(sym) === getindex && haskey(p, arguments(sym)[1]) &&
68-
continue
69-
push!(missing_params, sym)
86+
delete!(missing_params, sym)
87+
delete!(missing_params, ttsym)
88+
p[sym] = p[ttsym] = expr
7089
end
7190
end
7291

7392
isempty(missing_params) || throw(MissingParametersError(collect(missing_params)))
7493

94+
p = Dict(unwrap(k) => fixpoint_sub(v, bigdefs) for (k, v) in p)
95+
for (sym, _) in p
96+
if iscall(sym) && operation(sym) === getindex &&
97+
first(arguments(sym)) in all_ps
98+
error("Scalarized parameter values ($sym) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
99+
end
100+
end
101+
75102
tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length)
76103
for temp in ic.tunable_buffer_sizes)
77104
disc_buffer = Tuple(Vector{temp.type}(undef, temp.length)
@@ -135,8 +162,7 @@ function MTKParameters(
135162
# Don't narrow nonnumeric types
136163
nonnumeric_buffer = nonnumeric_buffer
137164

138-
if has_parameter_dependencies(sys) &&
139-
(pdeps = parameter_dependencies(sys)) !== nothing
165+
if pdeps !== nothing
140166
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
141167
dep_exprs = ArrayPartition((Any[missing for _ in 1:length(v)] for v in dep_buffer)...)
142168
for (sym, val) in pdeps

0 commit comments

Comments
 (0)