Skip to content

Commit 71e8f34

Browse files
authored
Merge pull request #2017 from SciML/myb/input_err
Better extra variable reporting and reverse the arrow in `compute_diff_label`
2 parents 2769903 + 63eb674 commit 71e8f34

File tree

4 files changed

+41
-11
lines changed

4 files changed

+41
-11
lines changed

src/structural_transformation/utils.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,28 @@ function BipartiteGraphs.maximal_matching(s::SystemStructure, eqfilter = eq -> t
1212
maximal_matching(s.graph, eqfilter, varfilter)
1313
end
1414

15-
function error_reporting(state, bad_idxs, n_highest_vars, iseqs)
15+
function error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
1616
io = IOBuffer()
17+
neqs = length(equations(state))
1718
if iseqs
1819
error_title = "More equations than variables, here are the potential extra equation(s):\n"
1920
out_arr = equations(state)[bad_idxs]
2021
else
2122
error_title = "More variables than equations, here are the potential extra variable(s):\n"
2223
out_arr = state.fullvars[bad_idxs]
24+
unset_inputs = intersect(out_arr, orig_inputs)
25+
n_missing_eqs = n_highest_vars - neqs
26+
n_unset_inputs = length(unset_inputs)
27+
if n_unset_inputs > 0
28+
println(io, "In particular, the unset input(s) are:")
29+
Base.print_array(io, unset_inputs)
30+
println(io)
31+
println(io, "The rest of potentially unset variable(s) are:")
32+
end
2333
end
2434

2535
Base.print_array(io, out_arr)
2636
msg = String(take!(io))
27-
neqs = length(equations(state))
2837
if iseqs
2938
throw(ExtraEquationsSystemException("The system is unbalanced. There are " *
3039
"$n_highest_vars highest order derivative variables "
@@ -43,7 +52,7 @@ end
4352
###
4453
### Structural check
4554
###
46-
function check_consistency(state::TearingState, ag = nothing)
55+
function check_consistency(state::TearingState, ag, orig_inputs)
4756
fullvars = state.fullvars
4857
@unpack graph, var_to_diff = state.structure
4958
n_highest_vars = count(v -> var_to_diff[v] === nothing &&
@@ -64,7 +73,7 @@ function check_consistency(state::TearingState, ag = nothing)
6473
else
6574
bad_idxs = findall(isequal(unassigned), var_eq_matching)
6675
end
67-
error_reporting(state, bad_idxs, n_highest_vars, iseqs)
76+
error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
6877
end
6978

7079
# This is defined to check if Pantelides algorithm terminates. For more

src/systems/abstractsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,7 @@ function linearization_function(sys::AbstractSystem, inputs,
11931193
return lin_fun, sys
11941194
end
11951195

1196-
function markio!(state, inputs, outputs; check = true)
1196+
function markio!(state, orig_inputs, inputs, outputs; check = true)
11971197
fullvars = state.fullvars
11981198
inputset = Dict{Any, Bool}(i => false for i in inputs)
11991199
outputset = Dict{Any, Bool}(o => false for o in outputs)
@@ -1207,6 +1207,9 @@ function markio!(state, inputs, outputs; check = true)
12071207
outputset[v] = true
12081208
fullvars[i] = v
12091209
else
1210+
if isinput(v)
1211+
push!(orig_inputs, v)
1212+
end
12101213
v = setio(v, false, false)
12111214
fullvars[i] = v
12121215
end
@@ -1221,7 +1224,7 @@ function markio!(state, inputs, outputs; check = true)
12211224
check && (all(values(outputset)) ||
12221225
error("Some specified outputs were not found in system. The following Dict indicates the found variables ",
12231226
outputset))
1224-
state
1227+
state, orig_inputs
12251228
end
12261229

12271230
"""

src/systems/systemstructure.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
200200
extra_eqs::Vector
201201
end
202202

203+
function Base.show(io::IO, state::TearingState)
204+
print(io, "TearingState of ", typeof(state.sys))
205+
end
206+
203207
struct EquationsView{T} <: AbstractVector{Any}
204208
ts::TearingState{T}
205209
end
@@ -386,9 +390,9 @@ Base.size(bgpm::SystemStructurePrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.b
386390
function compute_diff_label(diff_graph, i)
387391
di = i - 1 <= length(diff_graph) ? diff_graph[i - 1] : nothing
388392
ii = i - 1 <= length(invview(diff_graph)) ? invview(diff_graph)[i - 1] : nothing
389-
return Label(string(di === nothing ? "" : string(di, ''),
393+
return Label(string(di === nothing ? "" : string(di, ''),
390394
di !== nothing && ii !== nothing ? " " : "",
391-
ii === nothing ? "" : string(ii, '')))
395+
ii === nothing ? "" : string(ii, '')))
392396
end
393397
function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer)
394398
checkbounds(bgpm, i, j)
@@ -519,11 +523,14 @@ end
519523
function _structural_simplify!(state::TearingState, io; simplify = false,
520524
check_consistency = true, kwargs...)
521525
has_io = io !== nothing
522-
has_io && ModelingToolkit.markio!(state, io...)
526+
orig_inputs = Set()
527+
if has_io
528+
ModelingToolkit.markio!(state, orig_inputs, io...)
529+
end
523530
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
524531
sys, ag = ModelingToolkit.alias_elimination!(state; kwargs...)
525532
if check_consistency
526-
ModelingToolkit.check_consistency(state, ag)
533+
ModelingToolkit.check_consistency(state, ag, orig_inputs)
527534
end
528535
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
529536
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]

test/input_output_handling.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
using ModelingToolkit, Symbolics, Test
22
using ModelingToolkit: get_namespace, has_var, inputs, outputs, is_bound, bound_inputs,
3-
unbound_inputs, bound_outputs, unbound_outputs, isinput, isoutput
3+
unbound_inputs, bound_outputs, unbound_outputs, isinput, isoutput,
4+
ExtraVariablesSystemException
5+
6+
@variables t xx(t) some_input(t) [input = true]
7+
D = Differential(t)
8+
eqs = [D(xx) ~ some_input]
9+
@named model = ODESystem(eqs, t)
10+
@test_throws ExtraVariablesSystemException structural_simplify(model, ((), ()))
11+
if VERSION >= v"1.8"
12+
err = "In particular, the unset input(s) are:\n some_input(t)"
13+
@test_throws err structural_simplify(model, ((), ()))
14+
end
415

516
# Test input handling
617
@parameters tv

0 commit comments

Comments
 (0)