Skip to content

Commit 422c653

Browse files
authored
Merge pull request #1852 from SciML/myb/debugging
Add `debug_system` that gives a more helpful error message
2 parents e4d9101 + 187c824 commit 422c653

File tree

4 files changed

+89
-0
lines changed

4 files changed

+89
-0
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ include("systems/dependency_graphs.jl")
147147
include("systems/systemstructure.jl")
148148
using .SystemStructures
149149

150+
include("debugging.jl")
150151
include("systems/alias_elimination.jl")
151152
include("structural_transformation/StructuralTransformations.jl")
152153

@@ -210,5 +211,6 @@ export build_function
210211
export modelingtoolkitize
211212
export @variables, @parameters
212213
export @named, @nonamespace, @namespace, extend, compose
214+
export debug_system
213215

214216
end # module

src/debugging.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
const LOGGED_FUN = Set([log, sqrt, (^), /, inv])
2+
is_legal(::typeof(/), a, b) = is_legal(inv, b)
3+
is_legal(::typeof(inv), a) = !iszero(a)
4+
is_legal(::Union{typeof(log), typeof(sqrt)}, a) = a isa Complex || a >= zero(a)
5+
is_legal(::typeof(^), a, b) = a isa Complex || b isa Complex || isinteger(b) || a >= zero(a)
6+
7+
struct LoggedFun{F}
8+
f::F
9+
args::Any
10+
end
11+
Base.nameof(lf::LoggedFun) = nameof(lf.f)
12+
SymbolicUtils.promote_symtype(::LoggedFun, Ts...) = Real
13+
function (lf::LoggedFun)(args...)
14+
f = lf.f
15+
symbolic_args = lf.args
16+
if is_legal(f, args...)
17+
f(args...)
18+
else
19+
args_str = join(string.(symbolic_args .=> args), ", ", ", and ")
20+
throw(DomainError(args, "$(lf.f) errors with input(s): $args_str"))
21+
end
22+
end
23+
24+
function logged_fun(f, args...)
25+
# Currently we don't really support complex numbers
26+
term(LoggedFun(f, args), args..., type = Real)
27+
end
28+
29+
debug_sub(eq::Equation) = debug_sub(eq.lhs) ~ debug_sub(eq.rhs)
30+
function debug_sub(ex)
31+
istree(ex) || return ex
32+
f = operation(ex)
33+
args = map(debug_sub, arguments(ex))
34+
f in LOGGED_FUN ? logged_fun(f, args...) :
35+
similarterm(ex, f, args, metadata = metadata(ex))
36+
end

src/systems/abstractsystem.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,40 @@ end
953953
"""
954954
$(SIGNATURES)
955955
956+
Replace functions with singularities with a function that errors with symbolic
957+
information. E.g.
958+
959+
```julia-repl
960+
julia> sys = debug_system(sys);
961+
962+
julia> prob = ODEProblem(sys, [], (0, 1.0));
963+
964+
julia> du = zero(prob.u0);
965+
966+
julia> prob.f(du, prob.u0, prob.p, 0.0)
967+
ERROR: DomainError with (-1.0,):
968+
log errors with input(s): -cos(Q(t)) => -1.0
969+
Stacktrace:
970+
[1] (::ModelingToolkit.LoggedFun{typeof(log)})(args::Float64)
971+
...
972+
```
973+
"""
974+
function debug_system(sys::AbstractSystem)
975+
if has_systems(sys) && !isempty(get_systems(sys))
976+
error("debug_system only works on systems with no sub-systems!")
977+
end
978+
if has_eqs(sys)
979+
@set! sys.eqs = debug_sub.(equations(sys))
980+
end
981+
if has_observed(sys)
982+
@set! sys.observed = debug_sub.(observed(sys))
983+
end
984+
return sys
985+
end
986+
987+
"""
988+
$(SIGNATURES)
989+
956990
Structurally simplify algebraic equations in a system and compute the
957991
topological sort of the observed equations. When `simplify=true`, the `simplify`
958992
function will be applied during the tearing process. It also takes kwargs

test/odesystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,20 @@ eqs = [D(q) ~ -p / L - F
889889
testdict = Dict([:name => "test"])
890890
@named sys = ODESystem(eqs, t, metadata = testdict)
891891
@test get_metadata(sys) == testdict
892+
893+
@variables t P(t)=0 Q(t)=2
894+
∂t = Differential(t)
895+
896+
eqs = [∂t(Q) ~ 1 / sin(P)
897+
∂t(P) ~ log(-cos(Q))]
898+
@named sys = ODESystem(eqs, t, [P, Q], [])
899+
sys = debug_system(sys);
900+
prob = ODEProblem(sys, [], (0, 1.0));
901+
du = zero(prob.u0);
902+
if VERSION < v"1.8"
903+
@test_throws DomainError prob.f(du, [1, 0], prob.p, 0.0)
904+
@test_throws DomainError prob.f(du, [0, 2], prob.p, 0.0)
905+
else
906+
@test_throws "-cos(Q(t))" prob.f(du, [1, 0], prob.p, 0.0)
907+
@test_throws "sin(P(t))" prob.f(du, [0, 2], prob.p, 0.0)
908+
end

0 commit comments

Comments
 (0)