@@ -43,35 +43,62 @@ function MTKParameters(
43
43
end
44
44
defs = merge (defs, u0)
45
45
defs = merge (Dict (eq. lhs => eq. rhs for eq in observed (sys)), defs)
46
- p = merge (defs, p)
47
- p = merge (Dict (unwrap (k) => v for (k, v) in p),
48
- Dict (default_toterm (unwrap (k)) => v for (k, v) in p))
49
- p = Dict (unwrap (k) => fixpoint_sub (v, p) for (k, v) in p)
50
- for (sym, _) in p
51
- if iscall (sym) && operation (sym) === getindex &&
52
- first (arguments (sym)) in all_ps
53
- error (" Scalarized parameter values ($sym ) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`" )
46
+ bigdefs = merge (defs, p)
47
+ p = Dict ()
48
+ missing_params = Set ()
49
+ pdeps = has_parameter_dependencies (sys) ? parameter_dependencies (sys) : nothing
50
+
51
+ for sym in all_ps
52
+ ttsym = default_toterm (sym)
53
+ isarr = iscall (sym) && operation (sym) === getindex
54
+ arrparent = isarr ? arguments (sym)[1 ] : nothing
55
+ ttarrparent = isarr ? default_toterm (arrparent) : nothing
56
+ pname = hasname (sym) ? getname (sym) : nothing
57
+ ttpname = hasname (ttsym) ? getname (ttsym) : nothing
58
+ p[sym] = p[ttsym] = if haskey (bigdefs, sym)
59
+ bigdefs[sym]
60
+ elseif haskey (bigdefs, ttsym)
61
+ bigdefs[ttsym]
62
+ elseif haskey (bigdefs, pname)
63
+ isarr ? bigdefs[pname][arguments (sym)[2 : end ]. .. ] : bigdefs[pname]
64
+ elseif haskey (bigdefs, ttpname)
65
+ isarr ? bigdefs[ttpname][arguments (sym)[2 : end ]. .. ] : bigdefs[pname]
66
+ elseif isarr && haskey (bigdefs, arrparent)
67
+ bigdefs[arrparent][arguments (sym)[2 : end ]. .. ]
68
+ elseif isarr && haskey (bigdefs, ttarrparent)
69
+ bigdefs[ttarrparent][arguments (sym)[2 : end ]. .. ]
54
70
end
71
+ if get (p, sym, nothing ) === nothing
72
+ push! (missing_params, sym)
73
+ continue
74
+ end
75
+ # We may encounter the `ttsym` version first, add it to `missing_params`
76
+ # then encounter the "normal" version of a parameter or vice versa
77
+ # Remove the old one in `missing_params` just in case
78
+ delete! (missing_params, sym)
79
+ delete! (missing_params, ttsym)
55
80
end
56
81
57
- missing_params = Set ()
58
- for idxmap in (ic. tunable_idx, ic. discrete_idx, ic. constant_idx, ic. nonnumeric_idx)
59
- for sym in keys (idxmap)
60
- sym isa Symbol && continue
61
- haskey (p, sym) && continue
62
- hasname (sym) && haskey (p, getname (sym)) && continue
82
+ if pdeps != = nothing
83
+ for (sym, expr) in pdeps
84
+ sym = unwrap (sym)
63
85
ttsym = default_toterm (sym)
64
- haskey (p, ttsym) && continue
65
- hasname (ttsym) && haskey (p, getname (ttsym)) && continue
66
-
67
- iscall (sym) && operation (sym) === getindex && haskey (p, arguments (sym)[1 ]) &&
68
- continue
69
- push! (missing_params, sym)
86
+ delete! (missing_params, sym)
87
+ delete! (missing_params, ttsym)
88
+ p[sym] = p[ttsym] = expr
70
89
end
71
90
end
72
91
73
92
isempty (missing_params) || throw (MissingParametersError (collect (missing_params)))
74
93
94
+ p = Dict (unwrap (k) => fixpoint_sub (v, bigdefs) for (k, v) in p)
95
+ for (sym, _) in p
96
+ if iscall (sym) && operation (sym) === getindex &&
97
+ first (arguments (sym)) in all_ps
98
+ error (" Scalarized parameter values ($sym ) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`" )
99
+ end
100
+ end
101
+
75
102
tunable_buffer = Tuple (Vector {temp.type} (undef, temp. length)
76
103
for temp in ic. tunable_buffer_sizes)
77
104
disc_buffer = Tuple (Vector {temp.type} (undef, temp. length)
@@ -135,8 +162,7 @@ function MTKParameters(
135
162
# Don't narrow nonnumeric types
136
163
nonnumeric_buffer = nonnumeric_buffer
137
164
138
- if has_parameter_dependencies (sys) &&
139
- (pdeps = parameter_dependencies (sys)) != = nothing
165
+ if pdeps != = nothing
140
166
pdeps = Dict (k => fixpoint_sub (v, pdeps) for (k, v) in pdeps)
141
167
dep_exprs = ArrayPartition ((Any[missing for _ in 1 : length (v)] for v in dep_buffer). .. )
142
168
for (sym, val) in pdeps
0 commit comments