Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...)
end

@unpack graph, var_to_diff, solvable_graph = state.structure
mm = alias_eliminate_graph!(state, mm)
mm = alias_eliminate_graph!(state, mm; kwargs...)
s = state.structure
for g in (s.graph, s.solvable_graph)
g === nothing && continue
Expand Down Expand Up @@ -347,27 +347,29 @@ function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff)
(rank1, rank2, rank3, pivots)
end

function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL)
function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL; fully_determined = true, kwargs...)
@unpack structure = state
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
# subsystem of the system we're interested in.
#
ils, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(structure, ils)

## Step 2: Simplify the system using the Bareiss factorization
rk1vars = BitSet(@view pivots[1:rank1])
for v in solvable_variables
v in rk1vars && continue
@set! ils.nparentrows += 1
push!(ils.nzrows, ils.nparentrows)
push!(ils.row_cols, [v])
push!(ils.row_vals, [convert(eltype(ils), 1)])
add_vertex!(graph, SRC)
add_vertex!(solvable_graph, SRC)
add_edge!(graph, ils.nparentrows, v)
add_edge!(solvable_graph, ils.nparentrows, v)
add_vertex!(eq_to_diff)
if fully_determined == true
## Step 2: Simplify the system using the Bareiss factorization
rk1vars = BitSet(@view pivots[1:rank1])
for v in solvable_variables
v in rk1vars && continue
@set! ils.nparentrows += 1
push!(ils.nzrows, ils.nparentrows)
push!(ils.row_cols, [v])
push!(ils.row_vals, [convert(eltype(ils), 1)])
add_vertex!(graph, SRC)
add_vertex!(solvable_graph, SRC)
add_edge!(graph, ils.nparentrows, v)
add_edge!(solvable_graph, ils.nparentrows, v)
add_vertex!(eq_to_diff)
end
end

return ils
Expand Down
4 changes: 2 additions & 2 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ function _mtkcompile!(state::TearingState; simplify = false,
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
end
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
if check_consistency
fully_determined = ModelingToolkit.check_consistency(
state, orig_inputs; nothrow = fully_determined === nothing)
Expand All @@ -765,7 +765,7 @@ function _mtkcompile!(state::TearingState; simplify = false,
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
sys = pantelides_reassemble(state, var_eq_matching)
state = TearingState(sys)
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
sys = ModelingToolkit.dummy_derivative(
sys, state; simplify, mm, check_consistency, fully_determined, kwargs...)
else
Expand Down
5 changes: 2 additions & 3 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,14 +455,13 @@ sol = solve(prob, Tsit5())
# Initialize with an observed variable
prob = ODEProblem(simpsys, [z => 0.0], tspan, guesses = [x => 2.0, y => 4.0])
sol = solve(prob, Tsit5())
@test sol.u[1] == [0.0, 0.0]
@test sol[z, 1] == 0.0

prob = ODEProblem(simpsys, [z => 1.0, y => 1.0], tspan, guesses = [x => 2.0])
sol = solve(prob, Tsit5())
@test sol[[x, y], 1] == [0.0, 1.0]

# This should warn, but logging tests can't be marked as broken
@test_logs prob = ODEProblem(simpsys, [], tspan, guesses = [x => 2.0])
@test_warn "underdetermined" prob = ODEProblem(simpsys, [], tspan, guesses = [x => 2.0, y => 1.0])

# Late Binding initialization_eqs
# https://github.com/SciML/ModelingToolkit.jl/issues/2787
Expand Down
2 changes: 1 addition & 1 deletion test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ end
D(x) ~ p2,
x2 ~ p_1(x)
]
@mtkcompile sys = ODESystem(eq, t, [x, x2], [p_1, p2], discrete_events = [event])
@mtkcompile sys = System(eq, t, [x, x2], [p_1, p2], discrete_events = [event])

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob)
Expand Down
Loading