|
1 | | -const LOGGED_FUN = Set([log, sqrt, (^), /, inv]) |
2 | | -is_legal(::typeof(/), a, b) = is_legal(inv, b) |
3 | | -is_legal(::typeof(inv), a) = !iszero(a) |
4 | | -is_legal(::Union{typeof(log), typeof(sqrt)}, a) = a isa Complex || a >= zero(a) |
5 | | -is_legal(::typeof(^), a, b) = a isa Complex || b isa Complex || isinteger(b) || a >= zero(a) |
6 | | - |
| 1 | +struct LoggedFunctionException <: Exception |
| 2 | + msg::String |
| 3 | +end |
7 | 4 | struct LoggedFun{F} |
8 | 5 | f::F |
9 | 6 | args::Any |
| 7 | + error_nonfinite::Bool |
| 8 | +end |
| 9 | +function LoggedFunctionException(lf::LoggedFun, args, msg) |
| 10 | + LoggedFunctionException( |
| 11 | + "Function $(lf.f)($(join(lf.args, ", "))) " * msg * " with input" * |
| 12 | + join("\n " .* string.(lf.args .=> args)) # one line for each "var => val" for readability |
| 13 | + ) |
10 | 14 | end |
| 15 | +Base.showerror(io::IO, err::LoggedFunctionException) = print(io, err.msg) |
11 | 16 | Base.nameof(lf::LoggedFun) = nameof(lf.f) |
12 | 17 | SymbolicUtils.promote_symtype(::LoggedFun, Ts...) = Real |
13 | 18 | function (lf::LoggedFun)(args...) |
14 | | - f = lf.f |
15 | | - symbolic_args = lf.args |
16 | | - if is_legal(f, args...) |
17 | | - f(args...) |
18 | | - else |
19 | | - args_str = join(string.(symbolic_args .=> args), ", ", ", and ") |
20 | | - throw(DomainError(args, "$(lf.f) errors with input(s): $args_str")) |
| 19 | + val = try |
| 20 | + lf.f(args...) # try to call with numerical input, as usual |
| 21 | + catch err |
| 22 | + throw(LoggedFunctionException(lf, args, "errors")) # Julia automatically attaches original error message |
21 | 23 | end |
| 24 | + if lf.error_nonfinite && !isfinite(val) |
| 25 | + throw(LoggedFunctionException(lf, args, "output non-finite value $val")) |
| 26 | + end |
| 27 | + return val |
22 | 28 | end |
23 | 29 |
|
24 | | -function logged_fun(f, args...) |
| 30 | +function logged_fun(f, args...; error_nonfinite = true) # remember to update error_nonfinite in debug_system() docstring |
25 | 31 | # Currently we don't really support complex numbers |
26 | | - term(LoggedFun(f, args), args..., type = Real) |
| 32 | + term(LoggedFun(f, args, error_nonfinite), args..., type = Real) |
27 | 33 | end |
28 | 34 |
|
29 | | -debug_sub(eq::Equation) = debug_sub(eq.lhs) ~ debug_sub(eq.rhs) |
30 | | -function debug_sub(ex) |
| 35 | +function debug_sub(eq::Equation, funcs; kw...) |
| 36 | + debug_sub(eq.lhs, funcs; kw...) ~ debug_sub(eq.rhs, funcs; kw...) |
| 37 | +end |
| 38 | +function debug_sub(ex, funcs; kw...) |
31 | 39 | iscall(ex) || return ex |
32 | 40 | f = operation(ex) |
33 | | - args = map(debug_sub, arguments(ex)) |
34 | | - f in LOGGED_FUN ? logged_fun(f, args...) : |
| 41 | + args = map(ex -> debug_sub(ex, funcs; kw...), arguments(ex)) |
| 42 | + f in funcs ? logged_fun(f, args...; kw...) : |
35 | 43 | maketerm(typeof(ex), f, args, metadata(ex)) |
36 | 44 | end |
0 commit comments