@@ -189,28 +189,73 @@ function setu(bi::BatchedInterface)
189189 numprobs = length (bi. system_to_symbol_subset)
190190 probnames = [Symbol (:prob , i) for i in 1 : numprobs]
191191
192- fnbody = quote end
193- for (sys_idx, subset) in enumerate (bi. system_to_symbol_subset)
194- probname = probnames[sys_idx]
195- for (idx_in_subset, idx_in_union) in enumerate (subset)
196- idx = bi. system_to_symbol_indexes[sys_idx][idx_in_subset]
197- isstate = bi. system_to_isstate[sys_idx][idx_in_subset]
198- setter = isstate ? set_state! : set_parameter!
199- push! (fnbody. args, :($ setter ($ probname, vals[$ idx_in_union], $ idx)))
192+ full_update_fnexpr = let fnbody = quote end
193+ for (sys_idx, subset) in enumerate (bi. system_to_symbol_subset)
194+ probname = probnames[sys_idx]
195+ for (idx_in_subset, idx_in_union) in enumerate (subset)
196+ idx = bi. system_to_symbol_indexes[sys_idx][idx_in_subset]
197+ isstate = bi. system_to_isstate[sys_idx][idx_in_subset]
198+ setter = isstate ? set_state! : set_parameter!
199+ push! (fnbody. args, :($ setter ($ probname, vals[$ idx_in_union], $ idx)))
200+ end
201+ # also run hook
202+ if ! all (bi. system_to_isstate[sys_idx])
203+ paramidxs = [bi. system_to_symbol_indexes[sys_idx][idx_in_subset]
204+ for idx_in_subset in 1 : length (subset)
205+ if ! bi. system_to_isstate[sys_idx][idx_in_subset]]
206+ push! (fnbody. args, :($ finalize_parameters_hook! ($ probname, $ paramidxs)))
207+ end
200208 end
201- # also run hook
202- if ! all (bi. system_to_isstate[sys_idx])
203- paramidxs = [bi. system_to_symbol_indexes[sys_idx][idx_in_subset]
204- for idx_in_subset in 1 : length (subset)
205- if ! bi. system_to_isstate[sys_idx][idx_in_subset]]
206- push! (fnbody. args, :($ finalize_parameters_hook! ($ probname, $ paramidxs)))
209+ push! (fnbody. args, :(return vals))
210+ Expr (
211+ :function ,
212+ Expr (:tuple , probnames... , :vals ),
213+ fnbody
214+ )
215+ end
216+
217+ partial_update_fnexpr = let fnbody = quote end
218+ curfnbody = fnbody
219+ for (sys_idx, subset) in enumerate (bi. system_to_symbol_subset)
220+ newcurfnbody = if sys_idx == 1
221+ Expr (:if , :(idx == $ sys_idx))
222+ else
223+ Expr (:elseif , :(idx == $ sys_idx))
224+ end
225+ push! (curfnbody. args, newcurfnbody)
226+ curfnbody = newcurfnbody
227+
228+ ifbody = quote end
229+ push! (curfnbody. args, ifbody)
230+
231+ probname = :prob
232+ for (idx_in_subset, idx_in_union) in enumerate (subset)
233+ idx = bi. system_to_symbol_indexes[sys_idx][idx_in_subset]
234+ isstate = bi. system_to_isstate[sys_idx][idx_in_subset]
235+ setter = isstate ? set_state! : set_parameter!
236+ push! (ifbody. args, :($ setter ($ probname, vals[$ idx_in_union], $ idx)))
237+ end
238+ # also run hook
239+ if ! all (bi. system_to_isstate[sys_idx])
240+ paramidxs = [bi. system_to_symbol_indexes[sys_idx][idx_in_subset]
241+ for idx_in_subset in 1 : length (subset)
242+ if ! bi. system_to_isstate[sys_idx][idx_in_subset]]
243+ push! (ifbody. args, :($ finalize_parameters_hook! ($ probname, $ paramidxs)))
244+ end
207245 end
246+ push! (curfnbody. args, :(error (" Invalid problem index $idx " )))
247+ push! (fnbody. args, :(return nothing ))
248+ Expr (
249+ :function ,
250+ Expr (:tuple , :prob , :idx , :vals ),
251+ fnbody
252+ )
253+ end
254+ return let full_update = @RuntimeGeneratedFunction (full_update_fnexpr),
255+ partial_update = @RuntimeGeneratedFunction (partial_update_fnexpr)
256+
257+ setter! (args... ) = full_update (args... )
258+ setter! (prob, idx:: Int , vals:: AbstractVector ) = partial_update (prob, idx, vals)
259+ setter!
208260 end
209- push! (fnbody. args, :(return vals))
210- fnexpr = Expr (
211- :function ,
212- Expr (:tuple , probnames... , :vals ),
213- fnbody
214- )
215- return @RuntimeGeneratedFunction (fnexpr)
216261end
0 commit comments