Skip to content

Commit a110ee4

Browse files
refactor: remove ParameterIndexingProxy, add get_p and set_p functions
1 parent 4e75bff commit a110ee4

File tree

6 files changed

+116
-87
lines changed

6 files changed

+116
-87
lines changed

src/SymbolicIndexingInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ include("interface.jl")
1313
export SymbolCache
1414
include("symbol_cache.jl")
1515

16-
export ParameterIndexingProxy, parameter_values
17-
include("parameter_indexing_proxy.jl")
16+
export parameter_values, getp, setp
17+
include("parameter_indexing.jl")
1818

1919
@static if !isdefined(Base, :get_extension)
2020
function __init__()

src/parameter_indexing.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
parameter_values(p)
3+
4+
Return an indexable collection containing the value of each parameter in `p`.
5+
"""
6+
function parameter_values end
7+
8+
"""
9+
getp(sys, p)
10+
11+
Return a function that takes an integrator or solution of `sys`, and returns the value of
12+
the parameter `p`. Requires that the integrator or solution implement
13+
[`parameter_values`](@ref).
14+
"""
15+
function getp(sys, p)
16+
symtype = symbolic_type(p)
17+
elsymtype = symbolic_type(eltype(p))
18+
if symtype != NotSymbolic()
19+
return _getp(sys, symtype, p)
20+
else
21+
return _getp(sys, elsymtype, p)
22+
end
23+
end
24+
25+
function _getp(sys, ::NotSymbolic, p)
26+
return function getter(sol)
27+
return parameter_values(sol)[p]
28+
end
29+
end
30+
31+
function _getp(sys, ::ScalarSymbolic, p)
32+
idx = parameter_index(sys, p)
33+
return function getter(sol)
34+
return parameter_values(sol)[idx]
35+
end
36+
end
37+
38+
function _getp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray})
39+
idxs = parameter_index.((sys,), p)
40+
return function getter(sol)
41+
return getindex.((parameter_values(sol),), idxs)
42+
end
43+
end
44+
45+
function _getp(sys, ::ArraySymbolic, p)
46+
return getp(sys, collect(p))
47+
end
48+
49+
"""
50+
setp(sys, p)
51+
52+
Return a function that takes an integrator of `sys` and a value, and sets the
53+
the parameter `p` to that value. Requires that the integrator implement
54+
[`parameter_values`](@ref) and the returned collection be a mutable reference
55+
to the parameter vector in the integrator.
56+
"""
57+
function setp(sys, p)
58+
symtype = symbolic_type(p)
59+
elsymtype = symbolic_type(eltype(p))
60+
if symtype != NotSymbolic()
61+
return _setp(sys, symtype, p)
62+
else
63+
return _setp(sys, elsymtype, p)
64+
end
65+
end
66+
67+
function _setp(sys, ::NotSymbolic, p)
68+
return function setter!(sol, val)
69+
parameter_values(sol)[p] = val
70+
end
71+
end
72+
73+
function _setp(sys, ::ScalarSymbolic, p)
74+
idx = parameter_index(sys, p)
75+
return function setter!(sol, val)
76+
parameter_values(sol)[idx] = val
77+
end
78+
end
79+
80+
function _setp(sys, ::ScalarSymbolic, p::Union{Tuple,AbstractArray})
81+
idxs = parameter_index.((sys,), p)
82+
return function setter!(sol, val)
83+
setindex!.((parameter_values(sol),), val, idxs)
84+
end
85+
end
86+
87+
function _setp(sys, ::ArraySymbolic, p)
88+
return setp(sys, collect(p))
89+
end

src/parameter_indexing_proxy.jl

Lines changed: 0 additions & 51 deletions
This file was deleted.

test/parameter_indexing_proxy_test.jl

Lines changed: 0 additions & 32 deletions
This file was deleted.

test/parameter_indexing_test.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using SymbolicIndexingInterface
2+
using Symbolics
3+
4+
struct FakeIntegrator{P}
5+
p::P
6+
end
7+
8+
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
9+
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
10+
11+
@variables a[1:2] b
12+
sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t])
13+
p = [1.0, 2.0, 3.0]
14+
fi = FakeIntegrator(copy(p))
15+
for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))]
16+
get = getp(sys, sym)
17+
set! = setp(sys, sym)
18+
true_value = i isa Tuple ? getindex.((p,), i) : p[i]
19+
@test get(fi) == true_value
20+
set!(fi, 0.5 .* i)
21+
@test get(fi) == 0.5 .* i
22+
set!(fi, true_value)
23+
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ end
1313
@testset "Fallback test" begin
1414
@time include("fallback_test.jl")
1515
end
16-
@testset "Parameter indexing proxy test" begin
17-
@time include("parameter_indexing_proxy_test.jl")
16+
@testset "Parameter indexing test" begin
17+
@time include("parameter_indexing_test.jl")
1818
end

0 commit comments

Comments
 (0)