Skip to content

Commit 2ff166d

Browse files
Merge pull request #62 from SciML/as/remake-buffer
feat: add `remake_buffer`
2 parents 3473a4a + c6a1432 commit 2ff166d

File tree

8 files changed

+110
-1
lines changed

8 files changed

+110
-1
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,27 @@ authors = ["Aayush Sabharwal <[email protected]> and contributors"]
44
version = "0.3.11"
55

66
[deps]
7+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
78
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
89
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
10+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
911

1012
[compat]
1113
Aqua = "0.8"
14+
ArrayInterface = "7.9"
1215
MacroTools = "0.5.13"
1316
RuntimeGeneratedFunctions = "0.5"
1417
SafeTestsets = "0.0.1"
18+
StaticArrays = "1.9"
19+
StaticArraysCore = "1.4"
1520
Test = "1"
1621
julia = "1.10"
1722

1823
[extras]
1924
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2025
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
26+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2127
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2228

2329
[targets]
24-
test = ["Aqua", "Test", "SafeTestsets"]
30+
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ getu
5353
setu
5454
```
5555

56+
## Container objects
57+
58+
```@docs
59+
remake_buffer
60+
```
61+
5662
### Parameter timeseries
5763

5864
If a solution object saves a timeseries of parameter values that are updated during the

docs/src/complete_sii.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,9 @@ idxs = @show rand(Bool, 10) # boolean mask for indexing
452452
sol.ps[:a, idxs]
453453
```
454454

455+
## Custom containers
456+
457+
A custom container object (such as `ModelingToolkit.MTKParameters`) should implement
458+
[`remake_buffer`](@ref) to allow creating a new buffer with updated values, possibly
459+
with different types. This is already implemented for `AbstractArray`s (including static
460+
arrays).

docs/src/usage.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,13 @@ parameter_values(prob)
182182
on other problem/solution instances can be the key to achieving good performance. Note
183183
that this caching is allowed only when the symbolic system is unchanged (it's fine for
184184
the numerical values to have changed, but not the underlying equations).
185+
186+
## Re-creating a buffer
187+
188+
To re-create a buffer (of unknowns or parameters) use [`remake_buffer`](@ref). This allows
189+
changing the type of values in the buffer (for example, changing the value of a parameter
190+
from `Float64` to `Float32`).
191+
192+
```@example Usage
193+
remake_buffer(sys, prob.p, Dict(σ => 1f0, ρ => 2f0, β => 3f0))
194+
```

src/SymbolicIndexingInterface.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ module SymbolicIndexingInterface
22

33
import MacroTools
44
using RuntimeGeneratedFunctions
5+
import StaticArraysCore: MArray, similar_type
6+
import ArrayInterface
7+
58
RuntimeGeneratedFunctions.init(@__MODULE__)
69

710
export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
@@ -28,4 +31,7 @@ include("state_indexing.jl")
2831

2932
export ParameterIndexingProxy
3033
include("parameter_indexing_proxy.jl")
34+
35+
export remake_buffer
36+
include("remake.jl")
3137
end

src/remake.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
remake_buffer(sys, oldbuffer, vals::Dict)
3+
4+
Return a copy of the buffer `oldbuffer` with values from `vals`. The keys of `vals`
5+
are symbolic variables whose index in the buffer is determined using `sys`. The types of
6+
values in `vals` may not match the types of values stored at the corresponding indexes in
7+
the buffer, in which case the type of the buffer should be promoted accordingly. In
8+
general, this method should attempt to preserve the types of values stored in `vals` as
9+
much as possible. The returned buffer should be of the same type (ignoring type-parameters)
10+
as `oldbuffer`.
11+
12+
This method is already implemented for
13+
`remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
14+
as well.
15+
"""
16+
function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
17+
# similar when used with an `MArray` and nonconcrete eltype returns a
18+
# SizedArray. `similar_type` still returns an `MArray`
19+
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
20+
elT = Union{}
21+
for val in values(vals)
22+
elT = Union{elT, typeof(val)}
23+
end
24+
25+
newbuffer = similar(oldbuffer, elT)
26+
setu(sys, keys(vals))(newbuffer, values(vals))
27+
else
28+
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
29+
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)
30+
end
31+
return newbuffer
32+
end

test/remake_test.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using SymbolicIndexingInterface
2+
using StaticArrays
3+
4+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
5+
6+
for (buf, newbuf, newvals) in [
7+
# standard operation
8+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
9+
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
10+
# type "demotion"
11+
([1.0, 2.0, 3.0], [2, 3, 4],
12+
Dict(:x => 2, :y => 3, :z => 4))
13+
# type promotion
14+
([1, 2, 3], [2.0, 3.0, 4.0],
15+
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
16+
# union
17+
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
18+
Dict(:x => 2, :y => 3.0, :z => 4.0))
19+
# standard operation
20+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
21+
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
22+
# type "demotion"
23+
([1.0, 2.0, 3.0], [2, 3, 4],
24+
Dict(:a => 2, :b => 3, :c => 4))
25+
# type promotion
26+
([1, 2, 3], [2.0, 3.0, 4.0],
27+
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
28+
# union
29+
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
30+
Dict(:a => 2, :b => 3.0, :c => 4.0))]
31+
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
32+
buf = arrType(buf)
33+
newbuf = arrType(newbuf)
34+
_newbuf = remake_buffer(sys, buf, newvals)
35+
36+
@test _newbuf != buf # should not alias
37+
@test newbuf == _newbuf # test values
38+
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
39+
end
40+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,6 @@ end
2323
@safetestset "State indexing test" begin
2424
@time include("state_indexing_test.jl")
2525
end
26+
@safetestset "Remake test" begin
27+
@time include("remake_test.jl")
28+
end

0 commit comments

Comments
 (0)