|
| 1 | +import AbstractTrees |
| 2 | + |
| 3 | +const inspect_metadata = Ref{Bool}(false) |
| 4 | +function AbstractTrees.nodevalue(x::Symbolic) |
| 5 | + istree(x) ? operation(x) : x |
| 6 | +end |
| 7 | + |
| 8 | +function AbstractTrees.nodevalue(x::BasicSymbolic) |
| 9 | + str = if !istree(x) |
| 10 | + string(exprtype(x), "(", x, ")") |
| 11 | + elseif isadd(x) |
| 12 | + string(exprtype(x), |
| 13 | + (scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict))) |
| 14 | + elseif ismul(x) |
| 15 | + string(exprtype(x), |
| 16 | + (scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict))) |
| 17 | + elseif isdiv(x) || ispow(x) |
| 18 | + string(exprtype(x)) |
| 19 | + else |
| 20 | + string(exprtype(x),"{", operation(x), "}") |
| 21 | + end |
| 22 | + |
| 23 | + if inspect_metadata[] && !isnothing(metadata(x)) |
| 24 | + str *= string(" metadata=", Tuple(k=>v for (k, v) in metadata(x))) |
| 25 | + end |
| 26 | + Text(str) |
| 27 | +end |
| 28 | + |
| 29 | +function AbstractTrees.children(x::Symbolic) |
| 30 | + istree(x) ? arguments(x) : () |
| 31 | +end |
| 32 | + |
| 33 | +""" |
| 34 | + inspect([io::IO=stdout], expr; hint=true, metadata=false) |
| 35 | +
|
| 36 | +Inspect an expression tree `expr`. Uses AbstractTrees to print out an expression. |
| 37 | +
|
| 38 | +BasicSymbolic expressions will print the Unityper type (ADD, MUL, DIV, POW, SYM, TERM) and the relevant internals as the head, and the children in the subsequent lines as accessed by `arguments`. Other types will get printed as subtrees. Set `metadata=true` to print any metadata carried by the nodes. |
| 39 | +
|
| 40 | +Line numbers will be shown, use `pluck(expr, line_number)` to get the sub expression or leafnode starting at line_number. |
| 41 | +""" |
| 42 | +function inspect end |
| 43 | + |
| 44 | +function inspect(io::IO, x::Symbolic; |
| 45 | + hint=true, |
| 46 | + metadata=inspect_metadata[]) |
| 47 | + |
| 48 | + prev_state = inspect_metadata[] |
| 49 | + inspect_metadata[] = metadata |
| 50 | + lines = readlines(IOBuffer(sprint(io->AbstractTrees.print_tree(io, x)))) |
| 51 | + inspect_metadata[] = prev_state |
| 52 | + digits = ceil(Int, log10(length(lines))) |
| 53 | + line_numbers = lpad.(string.(1:length(lines)), digits) |
| 54 | + print(io, join(string.(line_numbers, " ", lines), "\n")) |
| 55 | + hint && print(io, "\n\nHint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number") |
| 56 | +end |
| 57 | + |
| 58 | +function inspect(x; hint=true, metadata=inspect_metadata[]) |
| 59 | + inspect(stdout, x; hint=hint, metadata=metadata) |
| 60 | +end |
| 61 | + |
| 62 | +inspect(io::IO, x; kw...) = println(io, "Not Symbolic: $x") |
| 63 | + |
| 64 | +""" |
| 65 | + pluck(expr, n) |
| 66 | +
|
| 67 | +Pluck the `n`th subexpression from `expr` as given by pre-order DFS. |
| 68 | +This is the same as the node numbering in `inspect`. |
| 69 | +""" |
| 70 | +function pluck(x, item) |
| 71 | + collect(Iterators.take(AbstractTrees.PreOrderDFS(x), item))[end] |
| 72 | +end |
0 commit comments