Skip to content

Commit 0e8e037

Browse files
feat: support updating individual problems with BatchedInterface
1 parent ac86106 commit 0e8e037

File tree

2 files changed

+73
-21
lines changed

2 files changed

+73
-21
lines changed

src/batched_interface.jl

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -189,28 +189,73 @@ function setu(bi::BatchedInterface)
189189
numprobs = length(bi.system_to_symbol_subset)
190190
probnames = [Symbol(:prob, i) for i in 1:numprobs]
191191

192-
fnbody = quote end
193-
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
194-
probname = probnames[sys_idx]
195-
for (idx_in_subset, idx_in_union) in enumerate(subset)
196-
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
197-
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
198-
setter = isstate ? set_state! : set_parameter!
199-
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
192+
full_update_fnexpr = let fnbody = quote end
193+
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
194+
probname = probnames[sys_idx]
195+
for (idx_in_subset, idx_in_union) in enumerate(subset)
196+
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
197+
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
198+
setter = isstate ? set_state! : set_parameter!
199+
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
200+
end
201+
# also run hook
202+
if !all(bi.system_to_isstate[sys_idx])
203+
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
204+
for idx_in_subset in 1:length(subset)
205+
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
206+
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
207+
end
200208
end
201-
# also run hook
202-
if !all(bi.system_to_isstate[sys_idx])
203-
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
204-
for idx_in_subset in 1:length(subset)
205-
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
206-
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
209+
push!(fnbody.args, :(return vals))
210+
Expr(
211+
:function,
212+
Expr(:tuple, probnames..., :vals),
213+
fnbody
214+
)
215+
end
216+
217+
partial_update_fnexpr = let fnbody = quote end
218+
curfnbody = fnbody
219+
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
220+
newcurfnbody = if sys_idx == 1
221+
Expr(:if, :(idx == $sys_idx))
222+
else
223+
Expr(:elseif, :(idx == $sys_idx))
224+
end
225+
push!(curfnbody.args, newcurfnbody)
226+
curfnbody = newcurfnbody
227+
228+
ifbody = quote end
229+
push!(curfnbody.args, ifbody)
230+
231+
probname = :prob
232+
for (idx_in_subset, idx_in_union) in enumerate(subset)
233+
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
234+
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
235+
setter = isstate ? set_state! : set_parameter!
236+
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
237+
end
238+
# also run hook
239+
if !all(bi.system_to_isstate[sys_idx])
240+
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
241+
for idx_in_subset in 1:length(subset)
242+
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
243+
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
244+
end
207245
end
246+
push!(curfnbody.args, :(error("Invalid problem index $idx")))
247+
push!(fnbody.args, :(return nothing))
248+
Expr(
249+
:function,
250+
Expr(:tuple, :prob, :idx, :vals),
251+
fnbody
252+
)
253+
end
254+
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
255+
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr)
256+
257+
setter!(args...) = full_update(args...)
258+
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals)
259+
setter!
208260
end
209-
push!(fnbody.args, :(return vals))
210-
fnexpr = Expr(
211-
:function,
212-
Expr(:tuple, probnames..., :vals),
213-
fnbody
214-
)
215-
return @RuntimeGeneratedFunction(fnexpr)
216261
end

test/batched_interface_test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,10 @@ setter!(probs..., buf)
4747
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
4848
# Similarly for :f
4949
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]
50+
51+
buf ./= 100
52+
setter!(probs[1], 1, buf)
53+
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
54+
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]
55+
56+
@test_throws ErrorException setter!(probs[1], 4, buf)

0 commit comments

Comments
 (0)