Skip to content

Commit b3404cf

Browse files
Merge pull request #58 from SciML/as/setp-callback
feat: add `finalize_parameters_hook!`
2 parents ba967a9 + ba71411 commit b3404cf

File tree

4 files changed

+59
-11
lines changed

4 files changed

+59
-11
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ observed
3535
```@docs
3636
parameter_values
3737
set_parameter!
38+
finalize_parameters_hook!
3839
getp
3940
setp
4041
ParameterIndexingProxy

docs/src/complete_sii.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,15 @@ function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val
267267
end
268268
```
269269

270+
### Using `finalize_parameters_hook!`
271+
272+
The function [`finalize_parameters_hook!`](@ref) is called exactly _once_ every time the
273+
function returned by `setp` is called. This allows performing any additional bookkeeping
274+
required when parameter values are updated. [`set_parameter!`](@ref) also allows performing
275+
similar functionality, but is called for every parameter that is updated, instead of just
276+
once. Thus, `finalize_parameters_hook!` is better for expensive computations that can be
277+
performed for a bulk parameter update.
278+
270279
# The `ParameterIndexingProxy`
271280

272281
[`ParameterIndexingProxy`](@ref) is a wrapper around another type which implements the

src/parameter_indexing.jl

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ function set_parameter!(sys::AbstractArray, val, idx)
8282
end
8383
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
8484

85+
"""
86+
finalize_parameters_hook!(prob, p)
87+
88+
This is a callback run one for each call to the function returned by [`setp`](@ref)
89+
which can be used to update internal data structures when parameters are modified.
90+
This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter
91+
that is updated.
92+
"""
93+
finalize_parameters_hook!(prob, p) = nothing
94+
8595
"""
8696
getp(sys, p)
8797
@@ -231,22 +241,36 @@ case `parameter_values` cannot return such a mutable reference, or additional ac
231241
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
232242
implemented.
233243
"""
234-
function setp(sys, p)
244+
function setp(sys, p; run_hook = true)
235245
symtype = symbolic_type(p)
236246
elsymtype = symbolic_type(eltype(p))
237-
_setp(sys, symtype, elsymtype, p)
247+
return if run_hook
248+
let _setter! = _setp(sys, symtype, elsymtype, p), p = p
249+
function setter!(prob, args...)
250+
res = _setter!(prob, args...)
251+
finalize_parameters_hook!(prob, p)
252+
res
253+
end
254+
end
255+
else
256+
_setp(sys, symtype, elsymtype, p)
257+
end
238258
end
239259

240260
function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
241-
return function setter!(sol, val)
242-
set_parameter!(sol, val, p)
261+
return let p = p
262+
function setter!(sol, val)
263+
set_parameter!(sol, val, p)
264+
end
243265
end
244266
end
245267

246268
function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
247269
idx = parameter_index(sys, p)
248-
return function setter!(sol, val)
249-
set_parameter!(sol, val, idx)
270+
return let idx = idx
271+
function setter!(sol, val)
272+
set_parameter!(sol, val, idx)
273+
end
250274
end
251275
end
252276

@@ -256,13 +280,15 @@ for (t1, t2) in [
256280
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
257281
]
258282
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
259-
setters = setp.((sys,), p)
260-
return function setter!(sol, val)
261-
map((s!, v) -> s!(sol, v), setters, val)
283+
setters = setp.((sys,), p; run_hook = false)
284+
return let setters = setters
285+
function setter!(sol, val)
286+
map((s!, v) -> s!(sol, v), setters, val)
287+
end
262288
end
263289
end
264290
end
265291

266292
function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)
267-
return setp(sys, collect(p))
293+
return setp(sys, collect(p); run_hook = false)
268294
end

test/parameter_indexing_test.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,21 @@ using Test
44
struct FakeIntegrator{S, P}
55
sys::S
66
p::P
7+
counter::Ref{Int}
78
end
89

910
function Base.getproperty(fi::FakeIntegrator, s::Symbol)
1011
s === :ps ? ParameterIndexingProxy(fi) : getfield(fi, s)
1112
end
1213
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
1314
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
15+
function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, p)
16+
fi.counter[] += 1
17+
end
1418

1519
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
1620
p = [1.0, 2.0, 3.0]
17-
fi = FakeIntegrator(sys, copy(p))
21+
fi = FakeIntegrator(sys, copy(p), Ref(0))
1822
new_p = [4.0, 5.0, 6.0]
1923
@test parameter_timeseries(fi) == [0]
2024
for (sym, oldval, newval, check_inference) in [
@@ -39,19 +43,25 @@ for (sym, oldval, newval, check_inference) in [
3943
end
4044
@test get(fi) == fi.ps[sym]
4145
@test get(fi) == oldval
46+
@test fi.counter[] == 0
4247
if check_inference
4348
@inferred set!(fi, newval)
4449
else
4550
set!(fi, newval)
4651
end
52+
@test fi.counter[] == 1
53+
4754
@test get(fi) == newval
4855
set!(fi, oldval)
4956
@test get(fi) == oldval
57+
@test fi.counter[] == 2
5058

5159
fi.ps[sym] = newval
5260
@test get(fi) == newval
61+
@test fi.counter[] == 3
5362
fi.ps[sym] = oldval
5463
@test get(fi) == oldval
64+
@test fi.counter[] == 4
5565

5666
if check_inference
5767
@inferred get(p)
@@ -65,6 +75,8 @@ for (sym, oldval, newval, check_inference) in [
6575
@test get(p) == newval
6676
set!(p, oldval)
6777
@test get(p) == oldval
78+
@test fi.counter[] == 4
79+
fi.counter[] = 0
6880
end
6981

7082
for (sym, val) in [

0 commit comments

Comments
 (0)