Skip to content

Commit 124e387

Browse files
Merge pull request #63 from SciML/as/tuple-parameter-values
feat: support indexing Tuple parameters, add tests
2 parents ed4bce0 + 30b4759 commit 124e387

File tree

8 files changed

+149
-101
lines changed

8 files changed

+149
-101
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ authors = ["Aayush Sabharwal <[email protected]> and contributors"]
44
version = "0.3.13"
55

66
[deps]
7+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1011
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1112

1213
[compat]
14+
Accessors = "0.1.36"
1315
Aqua = "0.8"
1416
ArrayInterface = "7.9"
1517
MacroTools = "0.5.13"

src/SymbolicIndexingInterface.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import MacroTools
44
using RuntimeGeneratedFunctions
55
import StaticArraysCore: MArray, similar_type
66
import ArrayInterface
7+
using Accessors: @reset
78

89
RuntimeGeneratedFunctions.init(@__MODULE__)
910

@@ -22,8 +23,9 @@ include("interface.jl")
2223
export SymbolCache
2324
include("symbol_cache.jl")
2425

25-
export parameter_values, set_parameter!, parameter_values_at_time,
26-
parameter_values_at_state_time, parameter_timeseries, getp, setp
26+
export parameter_values, set_parameter!, finalize_parameters_hook!,
27+
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp,
28+
setp
2729
include("parameter_indexing.jl")
2830

2931
export state_values, set_state!, current_time, getu, setu

src/parameter_indexing.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ argument version of this function returns the parameter value at index `i`. The
77
two-argument version of this function will default to returning
88
`parameter_values(p)[i]`.
99
10-
If this function is called with an `AbstractArray`, it will return the same array.
10+
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
11+
array/tuple.
1112
"""
1213
function parameter_values end
1314

1415
parameter_values(arr::AbstractArray) = arr
16+
parameter_values(arr::Tuple) = arr
1517
parameter_values(arr::AbstractArray, i) = arr[i]
18+
parameter_values(arr::Tuple, i) = arr[i]
1619
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
1720

1821
"""
@@ -77,7 +80,8 @@ See: [`parameter_values`](@ref)
7780
"""
7881
function set_parameter! end
7982

80-
function set_parameter!(sys::AbstractArray, val, idx)
83+
# Tuple only included for the error message
84+
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
8185
sys[idx] = val
8286
end
8387
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

src/remake.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ are symbolic variables whose index in the buffer is determined using `sys`. The
66
values in `vals` may not match the types of values stored at the corresponding indexes in
77
the buffer, in which case the type of the buffer should be promoted accordingly. In
88
general, this method should attempt to preserve the types of values stored in `vals` as
9-
much as possible. The returned buffer should be of the same type (ignoring type-parameters)
10-
as `oldbuffer`.
9+
much as possible. Types can be promoted for type-stability, to maintain performance. The
10+
returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`.
1111
1212
This method is already implemented for
1313
`remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
@@ -19,14 +19,30 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
1919
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
2020
elT = Union{}
2121
for val in values(vals)
22-
elT = Union{elT, typeof(val)}
22+
elT = promote_type(elT, typeof(val))
2323
end
2424

2525
newbuffer = similar(oldbuffer, elT)
26-
setu(sys, keys(vals))(newbuffer, values(vals))
26+
setu(sys, collect(keys(vals)))(newbuffer, elT.(values(vals)))
2727
else
2828
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
2929
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)
3030
end
3131
return newbuffer
3232
end
33+
34+
mutable struct TupleRemakeWrapper
35+
t::Tuple
36+
end
37+
38+
function set_parameter!(sys::TupleRemakeWrapper, val, idx)
39+
tp = sys.t
40+
@reset tp[idx] = val
41+
sys.t = tp
42+
end
43+
44+
function remake_buffer(sys, oldbuffer::Tuple, vals::Dict)
45+
wrap = TupleRemakeWrapper(oldbuffer)
46+
setu(sys, collect(keys(vals)))(wrap, values(vals))
47+
return wrap.t
48+
end

src/trait.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ symbolic_type(::Type{Expr}) = ScalarSymbolic()
4747
"""
4848
hasname(x)
4949
50-
Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`.
50+
Check whether the given symbolic variable (for which `symbolic_type(x) != NotSymbolic()`) has a valid name as per `getname`. Defaults to `true` for `x::Symbol`.
5151
"""
5252
function hasname end
5353

@@ -57,9 +57,11 @@ hasname(::Any) = false
5757
"""
5858
getname(x)::Symbol
5959
60-
Get the name of a symbolic variable as a `Symbol`
60+
Get the name of a symbolic variable as a `Symbol`. Acts as the identity function for
61+
`x::Symbol`.
6162
"""
6263
function getname end
64+
getname(x::Symbol) = x
6365

6466
"""
6567
symbolic_evaluate(expr, syms::Dict; kwargs...)

test/parameter_indexing_test.jl

Lines changed: 75 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator,
1717
end
1818

1919
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
20-
p = [1.0, 2.0, 3.0]
21-
fi = FakeIntegrator(sys, copy(p), Ref(0))
22-
new_p = [4.0, 5.0, 6.0]
23-
@test parameter_timeseries(fi) == [0]
24-
for (sym, oldval, newval, check_inference) in [
25-
(:a, p[1], new_p[1], true),
26-
(1, p[1], new_p[1], true),
27-
([:a, :b], p[1:2], new_p[1:2], true),
28-
(1:2, p[1:2], new_p[1:2], true),
29-
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
30-
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
31-
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
32-
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
33-
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
34-
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
35-
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
36-
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
37-
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
38-
]
39-
get = getp(sys, sym)
40-
set! = setp(sys, sym)
41-
if check_inference
42-
@inferred get(fi)
43-
end
44-
@test get(fi) == fi.ps[sym]
45-
@test get(fi) == oldval
46-
@test fi.counter[] == 0
47-
if check_inference
48-
@inferred set!(fi, newval)
49-
else
50-
set!(fi, newval)
51-
end
52-
@test fi.counter[] == 1
20+
for pType in [Vector, Tuple]
21+
p = [1.0, 2.0, 3.0]
22+
fi = FakeIntegrator(sys, pType(copy(p)), Ref(0))
23+
new_p = [4.0, 5.0, 6.0]
24+
@test parameter_timeseries(fi) == [0]
25+
for (sym, oldval, newval, check_inference) in [
26+
(:a, p[1], new_p[1], true),
27+
(1, p[1], new_p[1], true),
28+
([:a, :b], p[1:2], new_p[1:2], true),
29+
(1:2, p[1:2], new_p[1:2], true),
30+
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
31+
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
32+
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
33+
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
34+
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
35+
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
36+
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
37+
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
38+
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
39+
]
40+
get = getp(sys, sym)
41+
set! = setp(sys, sym)
42+
if check_inference
43+
@inferred get(fi)
44+
end
45+
@test get(fi) == fi.ps[sym]
46+
@test get(fi) == oldval
5347

54-
@test get(fi) == newval
55-
set!(fi, oldval)
56-
@test get(fi) == oldval
57-
@test fi.counter[] == 2
48+
if pType === Tuple
49+
@test_throws MethodError set!(fi, newval)
50+
continue
51+
end
5852

59-
fi.ps[sym] = newval
60-
@test get(fi) == newval
61-
@test fi.counter[] == 3
62-
fi.ps[sym] = oldval
63-
@test get(fi) == oldval
64-
@test fi.counter[] == 4
53+
@test fi.counter[] == 0
54+
if check_inference
55+
@inferred set!(fi, newval)
56+
else
57+
set!(fi, newval)
58+
end
59+
@test fi.counter[] == 1
6560

66-
if check_inference
67-
@inferred get(p)
68-
end
69-
@test get(p) == oldval
70-
if check_inference
71-
@inferred set!(p, newval)
72-
else
73-
set!(p, newval)
61+
@test get(fi) == newval
62+
set!(fi, oldval)
63+
@test get(fi) == oldval
64+
@test fi.counter[] == 2
65+
66+
fi.ps[sym] = newval
67+
@test get(fi) == newval
68+
@test fi.counter[] == 3
69+
fi.ps[sym] = oldval
70+
@test get(fi) == oldval
71+
@test fi.counter[] == 4
72+
73+
if check_inference
74+
@inferred get(p)
75+
end
76+
@test get(p) == oldval
77+
if check_inference
78+
@inferred set!(p, newval)
79+
else
80+
set!(p, newval)
81+
end
82+
@test get(p) == newval
83+
set!(p, oldval)
84+
@test get(p) == oldval
85+
@test fi.counter[] == 4
86+
fi.counter[] = 0
7487
end
75-
@test get(p) == newval
76-
set!(p, oldval)
77-
@test get(p) == oldval
78-
@test fi.counter[] == 4
79-
fi.counter[] = 0
80-
end
8188

82-
for (sym, val) in [
83-
([:a, :b, :c], p),
84-
([:c, :a], p[[3, 1]]),
85-
((:b, :a), p[[2, 1]]),
86-
((1, :c), p[[1, 3]])
87-
]
88-
buffer = zeros(length(sym))
89-
get = getp(sys, sym)
90-
@inferred get(buffer, fi)
91-
@test buffer == val
89+
for (sym, val) in [
90+
([:a, :b, :c], p),
91+
([:c, :a], p[[3, 1]]),
92+
((:b, :a), p[[2, 1]]),
93+
((1, :c), p[[1, 3]])
94+
]
95+
buffer = zeros(length(sym))
96+
get = getp(sys, sym)
97+
@inferred get(buffer, fi)
98+
@test buffer == val
99+
end
92100
end
93101

94102
struct FakeSolution

test/remake_test.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,23 @@ using StaticArrays
44
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
55

66
for (buf, newbuf, newvals) in [
7-
# standard operation
8-
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
9-
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
10-
# type "demotion"
11-
([1.0, 2.0, 3.0], [2, 3, 4],
12-
Dict(:x => 2, :y => 3, :z => 4))
13-
# type promotion
14-
([1, 2, 3], [2.0, 3.0, 4.0],
15-
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
16-
# union
17-
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
18-
Dict(:x => 2, :y => 3.0, :z => 4.0))
19-
# standard operation
20-
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
21-
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
22-
# type "demotion"
23-
([1.0, 2.0, 3.0], [2, 3, 4],
24-
Dict(:a => 2, :b => 3, :c => 4))
25-
# type promotion
26-
([1, 2, 3], [2.0, 3.0, 4.0],
27-
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
28-
# union
29-
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
30-
Dict(:a => 2, :b => 3.0, :c => 4.0))]
7+
# standard operation
8+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
9+
# buffer type "demotion"
10+
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:x => 2, :y => 3, :z => 4)),
11+
# buffer type promotion
12+
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
13+
# value type promotion
14+
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)),
15+
# standard operation
16+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
17+
# buffer type "demotion"
18+
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:a => 2, :b => 3, :c => 4)),
19+
# buffer type promotion
20+
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
21+
# value type promotion
22+
([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0))
23+
]
3124
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
3225
buf = arrType(buf)
3326
newbuf = arrType(newbuf)
@@ -38,3 +31,19 @@ for (buf, newbuf, newvals) in [
3831
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
3932
end
4033
end
34+
35+
# Tuples not allowed for state
36+
for (buf, newbuf, newvals) in [
37+
# standard operation
38+
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
39+
# buffer type "demotion"
40+
((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)),
41+
# buffer type promotion
42+
((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
43+
# value type promotion
44+
((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0))
45+
]
46+
_newbuf = remake_buffer(sys, buf, newvals)
47+
@test newbuf == _newbuf # test values
48+
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
49+
end

test/trait_test.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,8 @@ using Test
44
@test all(symbolic_type.([Int, Float64, String, Bool, UInt, Complex{Float64}]) .==
55
(NotSymbolic(),))
66
@test symbolic_type(Symbol) == ScalarSymbolic()
7+
@test hasname(:x)
8+
@test getname(:x) == :x
9+
@test !hasname(1)
10+
@test !hasname(1.0)
11+
@test !hasname("x")

0 commit comments

Comments
 (0)