Skip to content

Commit 7f3b1e1

Browse files
fix: properly handle initial parameters in complete
1 parent e2251af commit 7f3b1e1

File tree

1 file changed

+46
-30
lines changed

1 file changed

+46
-30
lines changed

src/systems/abstractsystem.jl

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -773,38 +773,21 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
773773
if !isempty(all_ps)
774774
# reorder parameters by portions
775775
ps_split = reorder_parameters(sys, all_ps)
776+
# if there are tunables, they will all be in `ps_split[1]`
777+
# and the arrays will have been scalarized
778+
ordered_ps = eltype(all_ps)[]
776779
# if there are no tunables, vcat them
777-
if isempty(get_index_cache(sys).tunable_idx)
778-
ordered_ps = reduce(vcat, ps_split)
779-
else
780-
# if there are tunables, they will all be in `ps_split[1]`
781-
# and the arrays will have been scalarized
782-
ordered_ps = eltype(all_ps)[]
783-
i = 1
784-
# go through all the tunables
785-
while i <= length(ps_split[1])
786-
sym = ps_split[1][i]
787-
# if the sym is not a scalarized array symbolic OR it was already scalarized,
788-
# just push it as-is
789-
if !iscall(sym) || operation(sym) != getindex ||
790-
any(isequal(sym), all_ps)
791-
push!(ordered_ps, sym)
792-
i += 1
793-
continue
794-
end
795-
# the next `length(sym)` symbols should be scalarized versions of the same
796-
# array symbolic
797-
if !allequal(first(arguments(x))
798-
for x in view(ps_split[1], i:(i + length(sym) - 1)))
799-
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
800-
end
801-
arrsym = first(arguments(sym))
802-
push!(ordered_ps, arrsym)
803-
i += length(arrsym)
804-
end
805-
ordered_ps = vcat(
806-
ordered_ps, reduce(vcat, ps_split[2:end]; init = eltype(ordered_ps)[]))
780+
if !isempty(get_index_cache(sys).tunable_idx)
781+
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
782+
ps_split = Base.tail(ps_split)
807783
end
784+
# unflatten initial parameters
785+
if !isempty(get_index_cache(sys).initials_idx)
786+
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
787+
ps_split = Base.tail(ps_split)
788+
end
789+
ordered_ps = vcat(
790+
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
808791
@set! sys.ps = ordered_ps
809792
end
810793
elseif has_index_cache(sys)
@@ -816,6 +799,39 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
816799
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
817800
end
818801

802+
"""
803+
$(TYPEDSIGNATURES)
804+
805+
Given a flattened array of parameters `params` and a collection of all (unscalarized)
806+
parameters in the system `all_ps`, unscalarize the elements in `params` and append
807+
to `buffer` in the same order as they are present in `params`. Effectively, if
808+
`params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`.
809+
"""
810+
function unflatten_parameters!(buffer, params, all_ps)
811+
i = 1
812+
# go through all the tunables
813+
while i <= length(params)
814+
sym = params[i]
815+
# if the sym is not a scalarized array symbolic OR it was already scalarized,
816+
# just push it as-is
817+
if !iscall(sym) || operation(sym) != getindex ||
818+
any(isequal(sym), all_ps)
819+
push!(buffer, sym)
820+
i += 1
821+
continue
822+
end
823+
# the next `length(sym)` symbols should be scalarized versions of the same
824+
# array symbolic
825+
if !allequal(first(arguments(x))
826+
for x in view(params, i:(i + length(sym) - 1)))
827+
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
828+
end
829+
arrsym = first(arguments(sym))
830+
push!(buffer, arrsym)
831+
i += length(arrsym)
832+
end
833+
end
834+
819835
for prop in [:eqs
820836
:tag
821837
:noiseeqs

0 commit comments

Comments
 (0)