Skip to content

Commit fb5eaa2

Browse files
Merge pull request #3540 from AayushSabharwal/as/ss-no-metadata
feat: reduce reliance on metadata in `structural_simplify`
2 parents b150fe2 + 4deb198 commit fb5eaa2

File tree

20 files changed

+233
-116
lines changed

20 files changed

+233
-116
lines changed

src/inputoutput.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
319319

320320
@set! sys.ps = [ps; new_parameters]
321321
@set! state.sys = sys
322-
@set! state.fullvars = new_fullvars
322+
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
323323
@set! state.structure = structure
324324
return state
325325
end

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ using ModelingToolkit: algeqs, EquationsView,
4040
dervars_range, diffvars_range, algvars_range,
4141
DiffGraph, complete!,
4242
get_fullvars, system_subset
43+
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic
4344

4445
using ModelingToolkit.DiffEqBase
4546
using ModelingToolkit.StaticArrays

src/structural_transformation/utils.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,15 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
228228
all_int_vars = false
229229
if !allow_symbolic
230230
if allow_parameter
231-
all(
232-
x -> ModelingToolkit.isparameter(x),
233-
vars(a)) || continue
231+
# if any of the variables in `a` are present in fullvars (taking into account arrays)
232+
if any(
233+
v -> any(isequal(v), fullvars) ||
234+
symbolic_type(v) == ArraySymbolic() &&
235+
Symbolics.shape(v) != Symbolics.Unknown() &&
236+
any(x -> any(isequal(x), fullvars), collect(v)),
237+
vars(a))
238+
continue
239+
end
234240
else
235241
continue
236242
end

src/systems/abstractsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -648,14 +648,14 @@ function (f::Initial)(x)
648648
iscall(x) && operation(x) isa Initial && return x
649649
result = if symbolic_type(x) == ArraySymbolic()
650650
# create an array for `Initial(array)`
651-
Symbolics.array_term(f, toparam(x))
651+
Symbolics.array_term(f, x)
652652
elseif iscall(x) && operation(x) == getindex
653653
# instead of `Initial(x[1])` create `Initial(x)[1]`
654654
# which allows parameter indexing to handle this case automatically.
655655
arr = arguments(x)[1]
656-
term(getindex, f(toparam(arr)), arguments(x)[2:end]...)
656+
term(getindex, f(arr), arguments(x)[2:end]...)
657657
else
658-
term(f, toparam(x))
658+
term(f, x)
659659
end
660660
# the result should be a parameter
661661
result = toparam(result)

src/systems/analysis_points.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ function linearization_function(sys::AbstractSystem,
950950
if output isa AnalysisPoint
951951
sys, (output_var,) = apply_transformation(AddVariable(output), sys)
952952
sys, (input_var,) = apply_transformation(GetInput(output), sys)
953-
push!(get_eqs(sys), output_var ~ input_var)
953+
@set! sys.eqs = [get_eqs(sys); output_var ~ input_var]
954954
else
955955
output_var = output
956956
end

src/systems/diffeqs/basic_transformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ function noise_to_brownians(sys::System; names::Union{Symbol, Vector{Symbol}} =
482482
"""))
483483
end
484484
brownvars = map(names) do name
485-
only(@brownian $name)
485+
unwrap(only(@brownian $name))
486486
end
487487

488488
terms = if ndims(neqs) == 1

src/systems/nonlinear/initializesystem.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,15 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
214214
initialization_eqs = filter(initialization_eqs) do eq
215215
empty!(vs)
216216
vars!(vs, eq; op = Initial)
217-
non_params = filter(!isparameter, vs)
217+
allpars = full_parameters(sys)
218+
for p in allpars
219+
if symbolic_type(p) == ArraySymbolic() &&
220+
Symbolics.shape(p) != Symbolics.Unknown()
221+
append!(allpars, Symbolics.scalarize(p))
222+
end
223+
end
224+
allpars = Set(allpars)
225+
non_params = filter(!in(allpars), vs)
218226
# error if non-parameters are present in the initialization equations
219227
if !isempty(non_params)
220228
throw(UnknownsInTimeIndependentInitializationError(eq, non_params))

src/systems/system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...)
206206
diffeqs = Equation[]
207207
othereqs = Equation[]
208208
for eq in eqs
209-
if !(eq.lhs isa Union{Symbolic, Number})
209+
if !(eq.lhs isa Union{Symbolic, Number, AbstractArray})
210210
push!(othereqs, eq)
211211
continue
212212
end

0 commit comments

Comments
 (0)