Skip to content

Commit f1b0036

Browse files
feat: add reorder_dimension_by_tunables!
1 parent 8f6828c commit f1b0036

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,6 @@ export debug_system
278278
export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime
279279
export Clock, SolverStepClock, TimeDomain
280280

281-
export MTKParameters
281+
export MTKParameters, reorder_dimension_by_tunables!, reorder_dimension_by_tunables
282282

283283
end # module

src/systems/index_cache.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,55 @@ end
496496
fntype_to_function_type(::Type{FnType{A, R, T}}) where {A, R, T} = T
497497
fntype_to_function_type(::Type{FnType{A, R, Nothing}}) where {A, R} = FunctionWrapper{R, A}
498498
fntype_to_function_type(::Type{FnType{A, R}}) where {A, R} = FunctionWrapper{R, A}
499+
500+
"""
501+
reorder_dimension_by_tunables!(dest::AbstractArray, sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
502+
503+
Assuming the order of values in dimension `dim` of `arr` correspond to the order of tunable
504+
parameters in the system, reorder them according to the order described in `syms`. `syms` must
505+
be a permutation of `tunable_parameters(sys)`. The result is written to `dest`. The `size` of `dest` and
506+
`arr` must be equal. Return `dest`.
507+
508+
See also: [`MTKParameters`](@ref), [`tunable_parameters`](@ref), [`reorder_dimension_by_tunables`](@ref).
509+
"""
510+
function reorder_dimension_by_tunables!(
511+
dest::AbstractArray, sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
512+
if !iscomplete(sys)
513+
throw(ArgumentError("A completed system is required. Call `complete` or `structural_simplify` on the system."))
514+
end
515+
if !has_index_cache(sys) || (ic = get_index_cache(sys)) === nothing
516+
throw(ArgumentError("The system does not have an index cache. Call `complete` or `structural_simplify` on the system with `split = true`."))
517+
end
518+
if size(dest) != size(arr)
519+
throw(ArgumentError("Source and destination arrays must have the same size. Got source array with size $(size(arr)) and destination with size $(size(dest))."))
520+
end
521+
522+
dsti = 1
523+
for sym in syms
524+
idx = parameter_index(ic, sym)
525+
if !(idx.portion isa SciMLStructures.Tunable)
526+
throw(ArgumentError("`syms` must be a permutation of `tunable_parameters(sys)`. Found $sym which is not a tunable parameter."))
527+
end
528+
529+
dstidx = ntuple(
530+
i -> i == dim ? (dsti:(dsti + length(sym) - 1)) : (:), Val(ndims(arr)))
531+
destv = @view dest[dstidx...]
532+
dsti += length(sym)
533+
arridx = ntuple(i -> i == dim ? (idx.idx) : (:), Val(ndims(arr)))
534+
srcv = @view arr[arridx...]
535+
copyto!(destv, srcv)
536+
end
537+
return dest
538+
end
539+
540+
"""
541+
reorder_dimension_by_tunables(sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
542+
543+
Out-of-place version of [`reorder_dimension_by_tunables!`](@ref).
544+
"""
545+
function reorder_dimension_by_tunables(
546+
sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
547+
buffer = similar(arr)
548+
reorder_dimension_by_tunables!(buffer, sys, arr, syms; dim)
549+
return buffer
550+
end

test/index_cache.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,32 @@ ic = ModelingToolkit.get_index_cache(sys)
6363
end
6464
end
6565

66+
@testset "reorder_dimension_by_tunables" begin
67+
@parameters p q[1:3] r[1:2, 1:2] s [tunable = false]
68+
@named sys = ODESystem(Equation[], t, [], [p, q, r, s])
69+
src = ones(8)
70+
dst = zeros(8)
71+
# system must be complete...
72+
@test_throws ArgumentError reorder_dimension_by_tunables!(dst, sys, src, [p, q, r])
73+
@test_throws ArgumentError reorder_dimension_by_tunables(sys, src, [p, q, r])
74+
sys = complete(sys; split = false)
75+
# with split = true...
76+
@test_throws ArgumentError reorder_dimension_by_tunables!(dst, sys, src, [p, q, r])
77+
@test_throws ArgumentError reorder_dimension_by_tunables(sys, src, [p, q, r])
78+
sys = complete(sys)
79+
# and the arrays must have matching size
80+
@test_throws ArgumentError reorder_dimension_by_tunables!(
81+
zeros(2, 4), sys, src, [p, q, r])
82+
83+
ps = MTKParameters(sys, [p => 1.0, q => 3ones(3), r => 4ones(2, 2), s => 0.0])
84+
src = ps.tunable
85+
reorder_dimension_by_tunables!(dst, sys, src, [q, r, p])
86+
@test dst vcat(3ones(3), 4ones(4), 1.0)
87+
@test reorder_dimension_by_tunables(sys, src, [r, p, q]) vcat(4ones(4), 1.0, 3ones(3))
88+
reorder_dimension_by_tunables!(dst, sys, src, [q[1], r[:, 1], q[2], r[:, 2], q[3], p])
89+
@test dst vcat(3.0, 4ones(2), 3.0, 4ones(2), 3.0, 1.0)
90+
src = stack([copy(ps.tunable) for i in 1:5]; dims = 1)
91+
dst = zeros(size(src))
92+
reorder_dimension_by_tunables!(dst, sys, src, [r, q, p]; dim = 2)
93+
@test dst stack([vcat(4ones(4), 3ones(3), 1.0) for i in 1:5]; dims = 1)
94+
end

0 commit comments

Comments
 (0)