Skip to content

Commit 7bf1d52

Browse files
Merge pull request #2531 from MasonProtter/patch-1
Add `getindex`/`setindex!` methods for `MTKParameters` with `ParameterIndex`
2 parents 422bd39 + 82c419d commit 7bf1d52

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/systems/parameter_buffer.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,33 @@ function Base.setindex!(p::MTKParameters, val, i)
363363
end
364364
end
365365

366+
function Base.getindex(p::MTKParameters, pind::ParameterIndex)
367+
(; portion, idx) = pind
368+
i, j, k... = idx
369+
if isempty(k)
370+
indexer = (v) -> v[i][j]
371+
else
372+
indexer = (v) -> v[i][j][k...]
373+
end
374+
if portion isa SciMLStructures.Tunable
375+
indexer(p.tunable)
376+
elseif portion isa SciMLStructures.Discrete
377+
indexer(p.discrete)
378+
elseif portion isa SciMLStructures.Constants
379+
indexer(p.constant)
380+
elseif portion === DEPENDENT_PORTION
381+
indexer(p.dependent)
382+
elseif portion === NONNUMERIC_PORTION
383+
indexer(p.nonnumeric)
384+
else
385+
error("Unhandled portion ", portion)
386+
end
387+
end
388+
389+
function Base.setindex!(p::MTKParameters, val, pind::ParameterIndex)
390+
SymbolicIndexingInterface.set_parameter!(p, val, pind)
391+
end
392+
366393
function Base.iterate(buf::MTKParameters, state = 1)
367394
total_len = 0
368395
total_len += _num_subarrays(buf.tunable)

test/split_parameters.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using ModelingToolkit, Test
22
using ModelingToolkitStandardLibrary.Blocks
33
using OrdinaryDiffEq
44
using ModelingToolkit: t_nounits as t, D_nounits as D
5+
using ModelingToolkit: MTKParameters, ParameterIndex, DEPENDENT_PORTION, NONNUMERIC_PORTION
6+
using SciMLStructures: Tunable, Discrete, Constants
57

68
x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]
79

@@ -189,3 +191,25 @@ connections = [[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4]
189191
connect(add.output, :u, model.torque.tau)]
190192
@named closed_loop = ODESystem(connections, t, systems = [model, state_feedback, add, d])
191193
S = get_sensitivity(closed_loop, :u)
194+
195+
@testset "Indexing MTKParameters with ParameterIndex" begin
196+
ps = MTKParameters(([1.0, 2.0], [3, 4]),
197+
([true, false], [[1 2; 3 4]]),
198+
([5, 6],),
199+
([7.0, 8.0],),
200+
(["hi", "bye"], [:lie, :die]),
201+
nothing,
202+
nothing)
203+
@test ps[ParameterIndex(Tunable(), (1, 2))] === 2.0
204+
@test ps[ParameterIndex(Tunable(), (2, 2))] === 4
205+
@test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] === 4
206+
@test ps[ParameterIndex(Discrete(), (2, 1))] == [1 2; 3 4]
207+
@test ps[ParameterIndex(Constants(), (1, 1))] === 5
208+
@test ps[ParameterIndex(DEPENDENT_PORTION, (1, 1))] === 7.0
209+
@test ps[ParameterIndex(NONNUMERIC_PORTION, (2, 2))] === :die
210+
211+
ps[ParameterIndex(Tunable(), (1, 2))] = 3.0
212+
ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] = 5
213+
@test ps[ParameterIndex(Tunable(), (1, 2))] === 3.0
214+
@test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] === 5
215+
end

0 commit comments

Comments
 (0)