1
- struct MTKParameters{T, D, C, E, F}
1
+ struct MTKParameters{T, D, C, E, F, G }
2
2
tunable:: T
3
3
discrete:: D
4
4
constant:: C
5
5
dependent:: E
6
- dependent_update:: F
6
+ dependent_update_iip:: F
7
+ dependent_update_oop:: G
7
8
end
8
9
9
10
function MTKParameters (sys:: AbstractSystem , p; tofloat = false , use_union = false )
@@ -19,12 +20,12 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
19
20
length (p) == length (ps) || error (" Invalid parameters" )
20
21
p = ps .=> p
21
22
end
22
- defs = Dict (default_toterm (unwrap (k)) => v for (k, v) in defaults (sys) if unwrap (k) in all_ps)
23
+ defs = Dict (default_toterm (unwrap (k)) => v for (k, v) in defaults (sys) if unwrap (k) in all_ps || default_toterm ( unwrap (k)) in all_ps )
23
24
if p isa SciMLBase. NullParameters
24
25
p = defs
25
26
else
26
- extra_params = Dict (unwrap (k) => v for (k, v) in p if ! in (unwrap (k), all_ps))
27
- p = merge (defs, Dict (default_toterm (unwrap (k)) => v for (k, v) in p if unwrap (k) in all_ps))
27
+ extra_params = Dict (unwrap (k) => v for (k, v) in p if ! in (unwrap (k), all_ps) && ! in ( default_toterm ( unwrap (k)), all_ps) )
28
+ p = merge (defs, Dict (default_toterm (unwrap (k)) => v for (k, v) in p if unwrap (k) in all_ps || default_toterm ( unwrap (k)) in all_ps ))
28
29
p = Dict (k => fixpoint_sub (v, extra_params) for (k, v) in p if ! haskey (extra_params, unwrap (k)))
29
30
end
30
31
@@ -74,39 +75,41 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
74
75
dep_exprs[idx] = wrap (fixpoint_sub (val, dependencies))
75
76
end
76
77
p = reorder_parameters (ic, parameters (sys))[begin : end - length (dep_buffer. x)]
77
- update_function = if isempty (dep_exprs. x)
78
- (_ ... ) -> ()
78
+ update_function_iip, update_function_oop = if isempty (dep_exprs. x)
79
+ nothing , nothing
79
80
else
80
- RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (build_function (dep_exprs, p... )[2 ])
81
+ oop, iip = build_function (dep_exprs, p... )
82
+ RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (iip), RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (oop)
81
83
end
82
84
# everything is an ArrayPartition so it's easy to figure out how many
83
85
# distinct vectors we have for each portion as `ArrayPartition.x`
84
86
if isempty (tunable_buffer. x)
85
- tunable_buffer = ArrayPartition ( Float64[])
87
+ tunable_buffer = Float64[]
86
88
end
87
89
if isempty (disc_buffer. x)
88
- disc_buffer = ArrayPartition ( Float64[])
90
+ disc_buffer = Float64[]
89
91
end
90
92
if isempty (const_buffer. x)
91
- const_buffer = ArrayPartition ( Float64[])
93
+ const_buffer = Float64[]
92
94
end
93
95
if isempty (dep_buffer. x)
94
- dep_buffer = ArrayPartition ( Float64[])
96
+ dep_buffer = Float64[]
95
97
end
96
98
if use_union
97
- tunable_buffer = ArrayPartition ( restrict_array_to_union (tunable_buffer) )
98
- disc_buffer = ArrayPartition ( restrict_array_to_union (disc_buffer) )
99
- const_buffer = ArrayPartition ( restrict_array_to_union (const_buffer) )
100
- dep_buffer = ArrayPartition ( restrict_array_to_union (dep_buffer) )
99
+ tunable_buffer = restrict_array_to_union (tunable_buffer)
100
+ disc_buffer = restrict_array_to_union (disc_buffer)
101
+ const_buffer = restrict_array_to_union (const_buffer)
102
+ dep_buffer = restrict_array_to_union (dep_buffer)
101
103
elseif tofloat
102
- tunable_buffer = ArrayPartition ( Float64 .(tunable_buffer) )
103
- disc_buffer = ArrayPartition ( Float64 .(disc_buffer) )
104
- const_buffer = ArrayPartition ( Float64 .(const_buffer) )
105
- dep_buffer = ArrayPartition ( Float64 .(dep_buffer) )
104
+ tunable_buffer = Float64 .(tunable_buffer)
105
+ disc_buffer = Float64 .(disc_buffer)
106
+ const_buffer = Float64 .(const_buffer)
107
+ dep_buffer = Float64 .(dep_buffer)
106
108
end
107
109
return MTKParameters{typeof (tunable_buffer), typeof (disc_buffer), typeof (const_buffer),
108
- typeof (dep_buffer), typeof (update_function)}(tunable_buffer,
109
- disc_buffer, const_buffer, dep_buffer, update_function)
110
+ typeof (dep_buffer), typeof (update_function_iip), typeof (update_function_oop)}(
111
+ tunable_buffer, disc_buffer, const_buffer, dep_buffer, update_function_iip,
112
+ update_function_oop)
110
113
end
111
114
112
115
SciMLStructures. isscimlstructure (:: MTKParameters ) = true
@@ -121,19 +124,27 @@ for (Portion, field) in [
121
124
@eval function SciMLStructures. canonicalize (:: $Portion , p:: MTKParameters )
122
125
function repack (values)
123
126
p.$ field .= values
124
- p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
127
+ if p. dependent_update_iip != = nothing
128
+ p. dependent_update_iip (p. dependent, p... )
129
+ end
130
+ p
125
131
end
126
132
return p.$ field, repack, true
127
133
end
128
134
129
135
@eval function SciMLStructures. replace (:: $Portion , p:: MTKParameters , newvals)
130
136
@set! p.$ field = newvals
131
- p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
137
+ if p. dependent_update_oop != = nothing
138
+ @set! p. dependent = ArrayPartition (p. dependent_update_oop (p... ))
139
+ end
140
+ p
132
141
end
133
142
134
143
@eval function SciMLStructures. replace! (:: $Portion , p:: MTKParameters , newvals)
135
144
p.$ field .= newvals
136
- p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
145
+ if p. dependent_update_iip != = nothing
146
+ p. dependent_update_iip (p. dependent, p... )
147
+ end
137
148
nothing
138
149
end
139
150
end
@@ -166,26 +177,32 @@ function SymbolicIndexingInterface.set_parameter!(p::MTKParameters, val, idx::Pa
166
177
else
167
178
error (" Unhandled portion $portion " )
168
179
end
169
- p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
180
+ if p. dependent_update_iip != = nothing
181
+ p. dependent_update_iip (p. dependent, p... )
182
+ end
170
183
end
171
184
185
+ _subarrays (v:: AbstractVector ) = isempty (v) ? () : (v,)
186
+ _subarrays (v:: ArrayPartition ) = v. x
187
+ _num_subarrays (v:: AbstractVector ) = 1
188
+ _num_subarrays (v:: ArrayPartition ) = length (v. x)
172
189
# for compiling callbacks
173
190
# getindex indexes the vectors, setindex! linearly indexes values
174
191
# it's inconsistent, but we need it to be this way
175
192
function Base. getindex (buf:: MTKParameters , i)
176
193
if ! isempty (buf. tunable)
177
- i <= length (buf. tunable. x ) && return buf. tunable. x [i]
178
- i -= length (buf. tunable. x )
194
+ i <= _num_subarrays (buf. tunable) && return _subarrays ( buf. tunable) [i]
195
+ i -= _num_subarrays (buf. tunable)
179
196
end
180
197
if ! isempty (buf. discrete)
181
- i <= length (buf. discrete. x ) && return buf. discrete. x [i]
182
- i -= length (buf. discrete. x )
198
+ i <= _num_subarrays (buf. discrete) && return _subarrays ( buf. discrete) [i]
199
+ i -= _num_subarrays (buf. discrete)
183
200
end
184
201
if ! isempty (buf. constant)
185
- i <= length (buf. constant. x ) && return buf. constant. x [i]
186
- i -= length (buf. constant. x )
202
+ i <= _num_subarrays (buf. constant) && return _subarrays ( buf. constant) [i]
203
+ i -= _num_subarrays (buf. constant)
187
204
end
188
- isempty (buf. dependent) || return buf. dependent. x [i]
205
+ isempty (buf. dependent) || return _subarrays ( buf. dependent) [i]
189
206
throw (BoundsError (buf, i))
190
207
end
191
208
function Base. setindex! (buf:: MTKParameters , val, i)
@@ -196,15 +213,17 @@ function Base.setindex!(buf::MTKParameters, val, i)
196
213
else
197
214
buf. constant[i - length (buf. tunable) - length (buf. discrete)] = val
198
215
end
199
- buf. dependent_update (buf. dependent, buf. tunable. x... , buf. discrete. x... , buf. constant. x... )
216
+ if buf. dependent_update_iip != = nothing
217
+ buf. dependent_update_iip (buf. dependent, buf... )
218
+ end
200
219
end
201
220
202
221
function Base. iterate (buf:: MTKParameters , state = 1 )
203
222
total_len = 0
204
- isempty (buf. tunable) || (total_len += length (buf. tunable. x ))
205
- isempty (buf. discrete) || (total_len += length (buf. discrete. x ))
206
- isempty (buf. constant) || (total_len += length (buf. constant. x ))
207
- isempty (buf. dependent) || (total_len += length (buf. dependent. x ))
223
+ isempty (buf. tunable) || (total_len += _num_subarrays (buf. tunable))
224
+ isempty (buf. discrete) || (total_len += _num_subarrays (buf. discrete))
225
+ isempty (buf. constant) || (total_len += _num_subarrays (buf. constant))
226
+ isempty (buf. dependent) || (total_len += _num_subarrays (buf. dependent))
208
227
if state <= total_len
209
228
return (buf[state], state + 1 )
210
229
else
@@ -229,14 +248,24 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
229
248
p_big = p_big
230
249
231
250
function (p_small_inner)
232
- tunable, repack, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), p_big)
233
- tunable[input_idxs] .= p_small_inner
234
- p_big = repack (tunable)
235
- pf (p_big)
251
+ for (i, val) in zip (input_idxs, p_small_inner)
252
+ set_parameter! (p_big, val, i)
253
+ end
254
+ # tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
255
+ # tunable[input_idxs] .= p_small_inner
256
+ # p_big = repack(tunable)
257
+ return if pf isa SciMLBase. ParamJacobianWrapper
258
+ buffer = similar (p_big. tunable, size (pf. u))
259
+ pf (buffer, p_big)
260
+ buffer
261
+ else
262
+ pf (p_big)
263
+ end
236
264
end
237
265
end
238
- tunable, _, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), p)
239
- p_small = tunable[input_idxs]
266
+ # tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
267
+ # p_small = tunable[input_idxs]
268
+ p_small = parameter_values .((p,), input_idxs)
240
269
cfg = ForwardDiff. JacobianConfig (p_closure, p_small, chunk, tag)
241
270
ForwardDiff. jacobian (p_closure, p_small, cfg, Val (false ))
242
271
end
0 commit comments