Skip to content

Commit f36e889

Browse files
Merge pull request #3069 from AayushSabharwal/as/tunable-order
feat: add `reorder_dimension_by_tunables!`, ensure `tunable_parameters` is correctly ordered
2 parents 949eca7 + da612c0 commit f36e889

File tree

11 files changed

+214
-10
lines changed

11 files changed

+214
-10
lines changed

docs/src/basics/FAQ.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,62 @@ D = Differential(x)
244244
@variables y(x)
245245
@named sys = ODESystem([D(y) ~ x], x)
246246
```
247+
248+
## Ordering of tunable parameters
249+
250+
Tunable parameters are floating point parameters, not used in callbacks and not marked with `tunable = false` in their metadata. These are expected to be used with AD
251+
and optimization libraries. As such, they are stored together in one `Vector{T}`. To obtain the ordering of tunable parameters in this buffer, use:
252+
253+
```@docs
254+
tunable_parameters
255+
```
256+
257+
If you have an array in which a particular dimension is in the order of tunable parameters (e.g. the jacobian with respect to tunables) then that dimension of the
258+
array can be reordered into the required permutation using the symbolic variables:
259+
260+
```@docs
261+
reorder_dimension_by_tunables!
262+
reorder_dimension_by_tunables
263+
```
264+
265+
For example:
266+
267+
```@example reorder
268+
using ModelingToolkit
269+
270+
@parameters p q[1:3] r[1:2, 1:2]
271+
272+
@named sys = ODESystem(Equation[], ModelingToolkit.t_nounits, [], [p, q, r])
273+
sys = complete(sys)
274+
```
275+
276+
The canonicalized tunables portion of `MTKParameters` will be in the order of tunables:
277+
278+
```@example reorder
279+
using SciMLStructures: canonicalize, Tunable
280+
281+
ps = MTKParameters(sys, [p => 1.0, q => [2.0, 3.0, 4.0], r => [5.0 6.0; 7.0 8.0]])
282+
arr = canonicalize(Tunable(), ps)[1]
283+
```
284+
285+
We can reorder this to contain the value for `p`, then all values for `q`, then for `r` using:
286+
287+
```@example reorder
288+
reorder_dimension_by_tunables(sys, arr, [p, q, r])
289+
```
290+
291+
This also works with interleaved subarrays of symbolics:
292+
293+
```@example reorder
294+
reorder_dimension_by_tunables(sys, arr, [q[1], r[1, :], q[2], r[2, :], q[3], p])
295+
```
296+
297+
And arbitrary dimensions of higher dimensional arrays:
298+
299+
```@example reorder
300+
highdimarr = stack([i * arr for i in 1:5]; dims = 1)
301+
```
302+
303+
```@example reorder
304+
reorder_dimension_by_tunables(sys, highdimarr, [q[1:2], r[1, :], q[3], r[2, :], p]; dim = 2)
305+
```

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,6 @@ export debug_system
278278
export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime
279279
export Clock, SolverStepClock, TimeDomain
280280

281-
export MTKParameters
281+
export MTKParameters, reorder_dimension_by_tunables!, reorder_dimension_by_tunables
282282

283283
end # module

src/systems/abstractsystem.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,44 @@ namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
882882
function complete(sys::AbstractSystem; split = true)
883883
if split && has_index_cache(sys)
884884
@set! sys.index_cache = IndexCache(sys)
885+
all_ps = parameters(sys)
886+
if !isempty(all_ps)
887+
# reorder parameters by portions
888+
ps_split = reorder_parameters(sys, all_ps)
889+
# if there are no tunables, vcat them
890+
if isempty(get_index_cache(sys).tunable_idx)
891+
ordered_ps = reduce(vcat, ps_split)
892+
else
893+
# if there are tunables, they will all be in `ps_split[1]`
894+
# and the arrays will have been scalarized
895+
ordered_ps = eltype(all_ps)[]
896+
i = 1
897+
# go through all the tunables
898+
while i <= length(ps_split[1])
899+
sym = ps_split[1][i]
900+
# if the sym is not a scalarized array symbolic OR it was already scalarized,
901+
# just push it as-is
902+
if !iscall(sym) || operation(sym) != getindex ||
903+
any(isequal(sym), all_ps)
904+
push!(ordered_ps, sym)
905+
i += 1
906+
continue
907+
end
908+
# the next `length(sym)` symbols should be scalarized versions of the same
909+
# array symbolic
910+
if !allequal(first(arguments(x))
911+
for x in view(ps_split[1], i:(i + length(sym) - 1)))
912+
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
913+
end
914+
arrsym = first(arguments(sym))
915+
push!(ordered_ps, arrsym)
916+
i += length(arrsym)
917+
end
918+
ordered_ps = vcat(
919+
ordered_ps, reduce(vcat, ps_split[2:end]; init = eltype(ordered_ps)[]))
920+
end
921+
@set! sys.ps = ordered_ps
922+
end
885923
end
886924
if isdefined(sys, :initializesystem) && get_initializesystem(sys) !== nothing
887925
@set! sys.initializesystem = complete(get_initializesystem(sys); split)

src/systems/index_cache.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,55 @@ end
505505
fntype_to_function_type(::Type{FnType{A, R, T}}) where {A, R, T} = T
506506
fntype_to_function_type(::Type{FnType{A, R, Nothing}}) where {A, R} = FunctionWrapper{R, A}
507507
fntype_to_function_type(::Type{FnType{A, R}}) where {A, R} = FunctionWrapper{R, A}
508+
509+
"""
510+
reorder_dimension_by_tunables!(dest::AbstractArray, sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
511+
512+
Assuming the order of values in dimension `dim` of `arr` correspond to the order of tunable
513+
parameters in the system, reorder them according to the order described in `syms`. `syms` must
514+
be a permutation of `tunable_parameters(sys)`. The result is written to `dest`. The `size` of `dest` and
515+
`arr` must be equal. Return `dest`.
516+
517+
See also: [`MTKParameters`](@ref), [`tunable_parameters`](@ref), [`reorder_dimension_by_tunables`](@ref).
518+
"""
519+
function reorder_dimension_by_tunables!(
520+
dest::AbstractArray, sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
521+
if !iscomplete(sys)
522+
throw(ArgumentError("A completed system is required. Call `complete` or `structural_simplify` on the system."))
523+
end
524+
if !has_index_cache(sys) || (ic = get_index_cache(sys)) === nothing
525+
throw(ArgumentError("The system does not have an index cache. Call `complete` or `structural_simplify` on the system with `split = true`."))
526+
end
527+
if size(dest) != size(arr)
528+
throw(ArgumentError("Source and destination arrays must have the same size. Got source array with size $(size(arr)) and destination with size $(size(dest))."))
529+
end
530+
531+
dsti = 1
532+
for sym in syms
533+
idx = parameter_index(ic, sym)
534+
if !(idx.portion isa SciMLStructures.Tunable)
535+
throw(ArgumentError("`syms` must be a permutation of `tunable_parameters(sys)`. Found $sym which is not a tunable parameter."))
536+
end
537+
538+
dstidx = ntuple(
539+
i -> i == dim ? (dsti:(dsti + length(sym) - 1)) : (:), Val(ndims(arr)))
540+
destv = @view dest[dstidx...]
541+
dsti += length(sym)
542+
arridx = ntuple(i -> i == dim ? (idx.idx) : (:), Val(ndims(arr)))
543+
srcv = @view arr[arridx...]
544+
copyto!(destv, srcv)
545+
end
546+
return dest
547+
end
548+
549+
"""
550+
reorder_dimension_by_tunables(sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
551+
552+
Out-of-place version of [`reorder_dimension_by_tunables!`](@ref).
553+
"""
554+
function reorder_dimension_by_tunables(
555+
sys::AbstractSystem, arr::AbstractArray, syms; dim = 1)
556+
buffer = similar(arr)
557+
reorder_dimension_by_tunables!(buffer, sys, arr, syms; dim)
558+
return buffer
559+
end

src/variables.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,12 @@ Create a tunable parameter by
381381
@parameters u [tunable=true]
382382
```
383383
384-
See also [`getbounds`](@ref), [`istunable`](@ref)
384+
For systems created with `split = true` (the default) and `default = true` passed to this function, the order
385+
of parameters returned is the order in which they are stored in the tunables portion of `MTKParameters`. Note
386+
that array variables will not be scalarized. To obtain the flattened representation of the tunables portion,
387+
call `Symbolics.scalarize(tunable_parameters(sys))` and concatenate the resulting arrays.
388+
389+
See also [`getbounds`](@ref), [`istunable`](@ref), [`MTKParameters`](@ref), [`complete`](@ref)
385390
"""
386391
function tunable_parameters(sys, p = parameters(sys); default = true)
387392
filter(x -> istunable(x, default), p)

test/discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ for df in [
4242
# iip
4343
du = zeros(3)
4444
u = collect(1:3)
45-
p = MTKParameters(syss, parameters(syss) .=> collect(1:5))
45+
p = MTKParameters(syss, [c, nsteps, δt, β, γ] .=> collect(1:5))
4646
df.f(du, u, p, 0)
4747
@test du [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
4848

test/index_cache.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SymbolicIndexingInterface
1+
using ModelingToolkit, SymbolicIndexingInterface, SciMLStructures
22
using ModelingToolkit: t_nounits as t
33

44
# Ensure indexes of array symbolics are cached appropriately
@@ -43,3 +43,52 @@ ic = ModelingToolkit.get_index_cache(sys)
4343
@test isequal(ic.symbol_to_variable[:x], x)
4444
@test isequal(ic.symbol_to_variable[:y], y)
4545
@test isequal(ic.symbol_to_variable[:z], z)
46+
47+
@testset "tunable_parameters is ordered" begin
48+
@parameters p q[1:3] r[1:2, 1:2] s [tunable = false]
49+
@named sys = ODESystem(Equation[], t, [], [p, q, r, s])
50+
sys = complete(sys)
51+
@test all(splat(isequal), zip(tunable_parameters(sys), parameters(sys)[1:3]))
52+
53+
offset = 1
54+
for par in tunable_parameters(sys)
55+
idx = parameter_index(sys, par)
56+
@test idx.portion isa SciMLStructures.Tunable
57+
if Symbolics.isarraysymbolic(par)
58+
@test vec(idx.idx) == offset:(offset + length(par) - 1)
59+
else
60+
@test idx.idx == offset
61+
end
62+
offset += length(par)
63+
end
64+
end
65+
66+
@testset "reorder_dimension_by_tunables" begin
67+
@parameters p q[1:3] r[1:2, 1:2] s [tunable = false]
68+
@named sys = ODESystem(Equation[], t, [], [p, q, r, s])
69+
src = ones(8)
70+
dst = zeros(8)
71+
# system must be complete...
72+
@test_throws ArgumentError reorder_dimension_by_tunables!(dst, sys, src, [p, q, r])
73+
@test_throws ArgumentError reorder_dimension_by_tunables(sys, src, [p, q, r])
74+
sys = complete(sys; split = false)
75+
# with split = true...
76+
@test_throws ArgumentError reorder_dimension_by_tunables!(dst, sys, src, [p, q, r])
77+
@test_throws ArgumentError reorder_dimension_by_tunables(sys, src, [p, q, r])
78+
sys = complete(sys)
79+
# and the arrays must have matching size
80+
@test_throws ArgumentError reorder_dimension_by_tunables!(
81+
zeros(2, 4), sys, src, [p, q, r])
82+
83+
ps = MTKParameters(sys, [p => 1.0, q => 3ones(3), r => 4ones(2, 2), s => 0.0])
84+
src = ps.tunable
85+
reorder_dimension_by_tunables!(dst, sys, src, [q, r, p])
86+
@test dst vcat(3ones(3), 4ones(4), 1.0)
87+
@test reorder_dimension_by_tunables(sys, src, [r, p, q]) vcat(4ones(4), 1.0, 3ones(3))
88+
reorder_dimension_by_tunables!(dst, sys, src, [q[1], r[:, 1], q[2], r[:, 2], q[3], p])
89+
@test dst vcat(3.0, 4ones(2), 3.0, 4ones(2), 3.0, 1.0)
90+
src = stack([copy(ps.tunable) for i in 1:5]; dims = 1)
91+
dst = zeros(size(src))
92+
reorder_dimension_by_tunables!(dst, sys, src, [r, q, p]; dim = 2)
93+
@test dst stack([vcat(4ones(4), 3ones(3), 1.0) for i in 1:5]; dims = 1)
94+
end

test/labelledarrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,6 @@ u0 = @LArray [9998.0, 1.0, 1.0, 1.0] (:S, :I, :R, :C)
8787
problem = ODEProblem(SIR!, u0, tspan, p)
8888
sys = complete(modelingtoolkitize(problem))
8989

90-
@test all(isequal.(parameters(sys), getproperty.(@variables(β, η, ω, φ, σ, μ), :val)))
90+
@test all(any(isequal(x), parameters(sys))
91+
for x in ModelingToolkit.unwrap.(@variables(β, η, ω, φ, σ, μ)))
9192
@test all(isequal.(Symbol.(unknowns(sys)), Symbol.(@variables(S(t), I(t), R(t), C(t)))))

test/model_parsing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,9 @@ end
534534
@named else_in_sys = InsideTheBlock(flag = 3)
535535
else_in_sys = complete(else_in_sys)
536536

537-
@test getname.(parameters(if_in_sys)) == [:if_parameter, :eq]
538-
@test getname.(parameters(elseif_in_sys)) == [:elseif_parameter, :eq]
539-
@test getname.(parameters(else_in_sys)) == [:else_parameter, :eq]
537+
@test sort(getname.(parameters(if_in_sys))) == [:eq, :if_parameter]
538+
@test sort(getname.(parameters(elseif_in_sys))) == [:elseif_parameter, :eq]
539+
@test sort(getname.(parameters(else_in_sys))) == [:else_parameter, :eq]
540540

541541
@test getdefault(if_in_sys.if_parameter) == 100
542542
@test getdefault(elseif_in_sys.elseif_parameter) == 101

test/precompile_test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using ODEPrecompileTest
1010

1111
u = collect(1:3)
1212
p = ModelingToolkit.MTKParameters(ODEPrecompileTest.f_noeval_good.sys,
13-
parameters(ODEPrecompileTest.f_noeval_good.sys) .=> collect(4:6))
13+
[, , ] .=> collect(4:6))
1414

1515
# These cases do not work, because they get defined in the ModelingToolkit's RGF cache.
1616
@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit

0 commit comments

Comments
 (0)