Skip to content

Commit de8b19b

Browse files
feat: support indexing Tuple parameters, add tests
1 parent ed4bce0 commit de8b19b

File tree

2 files changed

+81
-69
lines changed

2 files changed

+81
-69
lines changed

src/parameter_indexing.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ argument version of this function returns the parameter value at index `i`. The
77
two-argument version of this function will default to returning
88
`parameter_values(p)[i]`.
99
10-
If this function is called with an `AbstractArray`, it will return the same array.
10+
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
11+
array/tuple.
1112
"""
1213
function parameter_values end
1314

1415
parameter_values(arr::AbstractArray) = arr
16+
parameter_values(arr::Tuple) = arr
1517
parameter_values(arr::AbstractArray, i) = arr[i]
18+
parameter_values(arr::Tuple, i) = arr[i]
1619
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
1720

1821
"""
@@ -77,7 +80,8 @@ See: [`parameter_values`](@ref)
7780
"""
7881
function set_parameter! end
7982

80-
function set_parameter!(sys::AbstractArray, val, idx)
83+
# Tuple only included for the error message
84+
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
8185
sys[idx] = val
8286
end
8387
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

test/parameter_indexing_test.jl

Lines changed: 75 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator,
1717
end
1818

1919
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
20-
p = [1.0, 2.0, 3.0]
21-
fi = FakeIntegrator(sys, copy(p), Ref(0))
22-
new_p = [4.0, 5.0, 6.0]
23-
@test parameter_timeseries(fi) == [0]
24-
for (sym, oldval, newval, check_inference) in [
25-
(:a, p[1], new_p[1], true),
26-
(1, p[1], new_p[1], true),
27-
([:a, :b], p[1:2], new_p[1:2], true),
28-
(1:2, p[1:2], new_p[1:2], true),
29-
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
30-
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
31-
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
32-
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
33-
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
34-
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
35-
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
36-
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
37-
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
38-
]
39-
get = getp(sys, sym)
40-
set! = setp(sys, sym)
41-
if check_inference
42-
@inferred get(fi)
43-
end
44-
@test get(fi) == fi.ps[sym]
45-
@test get(fi) == oldval
46-
@test fi.counter[] == 0
47-
if check_inference
48-
@inferred set!(fi, newval)
49-
else
50-
set!(fi, newval)
51-
end
52-
@test fi.counter[] == 1
20+
for pType in [Vector, Tuple]
21+
p = [1.0, 2.0, 3.0]
22+
fi = FakeIntegrator(sys, pType(copy(p)), Ref(0))
23+
new_p = [4.0, 5.0, 6.0]
24+
@test parameter_timeseries(fi) == [0]
25+
for (sym, oldval, newval, check_inference) in [
26+
(:a, p[1], new_p[1], true),
27+
(1, p[1], new_p[1], true),
28+
([:a, :b], p[1:2], new_p[1:2], true),
29+
(1:2, p[1:2], new_p[1:2], true),
30+
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
31+
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
32+
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
33+
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
34+
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
35+
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
36+
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
37+
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
38+
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
39+
]
40+
get = getp(sys, sym)
41+
set! = setp(sys, sym)
42+
if check_inference
43+
@inferred get(fi)
44+
end
45+
@test get(fi) == fi.ps[sym]
46+
@test get(fi) == oldval
5347

54-
@test get(fi) == newval
55-
set!(fi, oldval)
56-
@test get(fi) == oldval
57-
@test fi.counter[] == 2
48+
if pType === Tuple
49+
@test_throws MethodError set!(fi, newval)
50+
continue
51+
end
5852

59-
fi.ps[sym] = newval
60-
@test get(fi) == newval
61-
@test fi.counter[] == 3
62-
fi.ps[sym] = oldval
63-
@test get(fi) == oldval
64-
@test fi.counter[] == 4
53+
@test fi.counter[] == 0
54+
if check_inference
55+
@inferred set!(fi, newval)
56+
else
57+
set!(fi, newval)
58+
end
59+
@test fi.counter[] == 1
6560

66-
if check_inference
67-
@inferred get(p)
68-
end
69-
@test get(p) == oldval
70-
if check_inference
71-
@inferred set!(p, newval)
72-
else
73-
set!(p, newval)
61+
@test get(fi) == newval
62+
set!(fi, oldval)
63+
@test get(fi) == oldval
64+
@test fi.counter[] == 2
65+
66+
fi.ps[sym] = newval
67+
@test get(fi) == newval
68+
@test fi.counter[] == 3
69+
fi.ps[sym] = oldval
70+
@test get(fi) == oldval
71+
@test fi.counter[] == 4
72+
73+
if check_inference
74+
@inferred get(p)
75+
end
76+
@test get(p) == oldval
77+
if check_inference
78+
@inferred set!(p, newval)
79+
else
80+
set!(p, newval)
81+
end
82+
@test get(p) == newval
83+
set!(p, oldval)
84+
@test get(p) == oldval
85+
@test fi.counter[] == 4
86+
fi.counter[] = 0
7487
end
75-
@test get(p) == newval
76-
set!(p, oldval)
77-
@test get(p) == oldval
78-
@test fi.counter[] == 4
79-
fi.counter[] = 0
80-
end
8188

82-
for (sym, val) in [
83-
([:a, :b, :c], p),
84-
([:c, :a], p[[3, 1]]),
85-
((:b, :a), p[[2, 1]]),
86-
((1, :c), p[[1, 3]])
87-
]
88-
buffer = zeros(length(sym))
89-
get = getp(sys, sym)
90-
@inferred get(buffer, fi)
91-
@test buffer == val
89+
for (sym, val) in [
90+
([:a, :b, :c], p),
91+
([:c, :a], p[[3, 1]]),
92+
((:b, :a), p[[2, 1]]),
93+
((1, :c), p[[1, 3]])
94+
]
95+
buffer = zeros(length(sym))
96+
get = getp(sys, sym)
97+
@inferred get(buffer, fi)
98+
@test buffer == val
99+
end
92100
end
93101

94102
struct FakeSolution

0 commit comments

Comments
 (0)