Skip to content

Commit a94caad

Browse files
authored
Merge pull request #538 from JuliaSymbolics/s/inspect
inspect & pluck
2 parents 8f9f5f1 + 33278b6 commit a94caad

File tree

13 files changed

+161
-53
lines changed

13 files changed

+161
-53
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
5050
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5151
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
5252
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
53+
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
5354
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5455
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5556

5657
[targets]
57-
test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "Test", "Zygote"]
58+
test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]

src/SymbolicUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ include("methods.jl")
2727
# LinkedList, simplification utilities
2828
include("utils.jl")
2929

30+
# Tree inspection
31+
include("inspect.jl")
3032
export Rewriters
3133

3234
# A library for composing together expr -> expr functions

src/inspect.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import AbstractTrees
2+
3+
const inspect_metadata = Ref{Bool}(false)
4+
function AbstractTrees.nodevalue(x::Symbolic)
5+
istree(x) ? operation(x) : x
6+
end
7+
8+
function AbstractTrees.nodevalue(x::BasicSymbolic)
9+
str = if !istree(x)
10+
string(exprtype(x), "(", x, ")")
11+
elseif isadd(x)
12+
string(exprtype(x),
13+
(scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
14+
elseif ismul(x)
15+
string(exprtype(x),
16+
(scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
17+
elseif isdiv(x) || ispow(x)
18+
string(exprtype(x))
19+
else
20+
string(exprtype(x),"{", operation(x), "}")
21+
end
22+
23+
if inspect_metadata[] && !isnothing(metadata(x))
24+
str *= string(" metadata=", Tuple(k=>v for (k, v) in metadata(x)))
25+
end
26+
Text(str)
27+
end
28+
29+
function AbstractTrees.children(x::Symbolic)
30+
istree(x) ? arguments(x) : ()
31+
end
32+
33+
"""
34+
inspect([io::IO=stdout], expr; hint=true, metadata=false)
35+
36+
Inspect an expression tree `expr`. Uses AbstractTrees to print out an expression.
37+
38+
BasicSymbolic expressions will print the Unityper type (ADD, MUL, DIV, POW, SYM, TERM) and the relevant internals as the head, and the children in the subsequent lines as accessed by `arguments`. Other types will get printed as subtrees. Set `metadata=true` to print any metadata carried by the nodes.
39+
40+
Line numbers will be shown, use `pluck(expr, line_number)` to get the sub expression or leafnode starting at line_number.
41+
"""
42+
function inspect end
43+
44+
function inspect(io::IO, x::Symbolic;
45+
hint=true,
46+
metadata=inspect_metadata[])
47+
48+
prev_state = inspect_metadata[]
49+
inspect_metadata[] = metadata
50+
lines = readlines(IOBuffer(sprint(io->AbstractTrees.print_tree(io, x))))
51+
inspect_metadata[] = prev_state
52+
digits = ceil(Int, log10(length(lines)))
53+
line_numbers = lpad.(string.(1:length(lines)), digits)
54+
print(io, join(string.(line_numbers, " ", lines), "\n"))
55+
hint && print(io, "\n\nHint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number")
56+
end
57+
58+
function inspect(x; hint=true, metadata=inspect_metadata[])
59+
inspect(stdout, x; hint=hint, metadata=metadata)
60+
end
61+
62+
inspect(io::IO, x; kw...) = println(io, "Not Symbolic: $x")
63+
64+
"""
65+
pluck(expr, n)
66+
67+
Pluck the `n`th subexpression from `expr` as given by pre-order DFS.
68+
This is the same as the node numbering in `inspect`.
69+
"""
70+
function pluck(x, item)
71+
collect(Iterators.take(AbstractTrees.PreOrderDFS(x), item))[end]
72+
end

src/types.jl

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -559,54 +559,6 @@ function basic_similarterm(t, f, args, stype; metadata=nothing)
559559
end
560560
end
561561

562-
###
563-
### Tree print
564-
###
565-
566-
import AbstractTrees
567-
568-
struct TreePrint
569-
op
570-
x
571-
end
572-
573-
function AbstractTrees.children(x::BasicSymbolic)
574-
if isterm(x) || ispow(x)
575-
return arguments(x)
576-
elseif isadd(x) || ismul(x)
577-
children = Any[x.coeff]
578-
for (key, coeff) in pairs(x.dict)
579-
if coeff == 1
580-
push!(children, key)
581-
else
582-
push!(children, TreePrint(isadd(x) ? (:*) : (:^), (key, coeff)))
583-
end
584-
end
585-
return children
586-
end
587-
end
588-
589-
AbstractTrees.children(x::TreePrint) = [x.x[1], x.x[2]]
590-
591-
print_tree(x; show_type=false, maxdepth=Inf, kw...) = print_tree(stdout, x; show_type=show_type, maxdepth=maxdepth, kw...)
592-
593-
function print_tree(_io::IO, x::BasicSymbolic; show_type=false, kw...)
594-
if isterm(x) || isadd(x) || ismul(x) || ispow(x) || isdiv(x)
595-
AbstractTrees.print_tree(_io, x; withinds=true, kw...) do io, y, inds
596-
if istree(y)
597-
print(io, operation(y))
598-
elseif y isa TreePrint
599-
print(io, "(", y.op, ")")
600-
else
601-
print(io, y)
602-
end
603-
if !(y isa TreePrint) && show_type
604-
print(io, " [", typeof(y), "]")
605-
end
606-
end
607-
end
608-
end
609-
610562
###
611563
### Metadata
612564
###

test/basics.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,18 @@ end
188188
@test repr((-1)^a) == "(-1)^a"
189189
end
190190

191+
@testset "inspect" begin
192+
@syms x y z
193+
y = SymbolicUtils.setmetadata(y, Integer, 42) # Set some metadata
194+
ex = z*(2x + 3y + 1)^2/(z+2x)
195+
@test_reference "inspect_output/ex.txt" sprint(io->SymbolicUtils.inspect(io, ex))
196+
@test_reference "inspect_output/ex-md.txt" sprint(io->SymbolicUtils.inspect(io, ex, metadata=true))
197+
@test_reference "inspect_output/ex-nohint.txt" sprint(io->SymbolicUtils.inspect(io, ex, hint=false))
198+
@test SymbolicUtils.pluck(ex, 8) == 2
199+
@test_reference "inspect_output/sub10.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 10)))
200+
@test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14)))
201+
end
202+
191203
@testset "similarterm" begin
192204
@syms a b c
193205
@test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1))

test/fuzz.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ include("fuzzlib.jl")
22

33
using Random: seed!
44

5-
seed!(6174)
6-
@testset "Fuzz test" begin
5+
seed!(8258)
76
@time @testset "expand fuzz" begin
87
for i=1:500
98
i % 100 == 0 && @info "expand fuzz" iter=i
@@ -45,4 +44,3 @@ seed!(6174)
4544
fuzz_addmulpow(4)
4645
end
4746
end
48-
end

test/fuzzlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ const num_spec = let
4343
()->rand([a b c d e f])]
4444

4545
binops = SymbolicUtils.diadic
46-
nopow = filter(x->x!==(^), binops)
46+
nopow = setdiff(binops, [(^), besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk])
4747
twoargfns = vcat(nopow, (x,y)->x isa Union{Int, Rational, Complex{<:Rational}} ? x * y : x^y)
4848
fns = vcat(1 .=> vcat(SymbolicUtils.monadic, [one, zero]),
4949
2 .=> vcat(twoargfns, fill(+, 5), [-,-], fill(*, 5), fill(/, 40)),

test/inspect_output/ex-md.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
1 DIV
2+
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3+
3 │ ├─ SYM(z)
4+
4 │ └─ POW
5+
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6+
6 │ │ ├─ 1
7+
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8+
8 │ │ │ ├─ 2
9+
9 │ │ │ └─ SYM(x)
10+
10 │ │ └─ MUL(scalar = 3, powers = (y => 1,))
11+
11 │ │ ├─ 3
12+
12 │ │ └─ SYM(y) metadata=(Integer => 42,)
13+
13 │ └─ 2
14+
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15+
15 ├─ SYM(z)
16+
16 └─ MUL(scalar = 2, powers = (x => 1,))
17+
17 ├─ 2
18+
18 └─ SYM(x)
19+
20+
Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number

test/inspect_output/ex-nohint.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
1 DIV
2+
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3+
3 │ ├─ SYM(z)
4+
4 │ └─ POW
5+
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6+
6 │ │ ├─ 1
7+
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8+
8 │ │ │ ├─ 2
9+
9 │ │ │ └─ SYM(x)
10+
10 │ │ └─ MUL(scalar = 3, powers = (y => 1,))
11+
11 │ │ ├─ 3
12+
12 │ │ └─ SYM(y)
13+
13 │ └─ 2
14+
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15+
15 ├─ SYM(z)
16+
16 └─ MUL(scalar = 2, powers = (x => 1,))
17+
17 ├─ 2
18+
18 └─ SYM(x)

test/inspect_output/ex.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
1 DIV
2+
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3+
3 │ ├─ SYM(z)
4+
4 │ └─ POW
5+
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6+
6 │ │ ├─ 1
7+
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8+
8 │ │ │ ├─ 2
9+
9 │ │ │ └─ SYM(x)
10+
10 │ │ └─ MUL(scalar = 3, powers = (y => 1,))
11+
11 │ │ ├─ 3
12+
12 │ │ └─ SYM(y)
13+
13 │ └─ 2
14+
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15+
15 ├─ SYM(z)
16+
16 └─ MUL(scalar = 2, powers = (x => 1,))
17+
17 ├─ 2
18+
18 └─ SYM(x)
19+
20+
Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number

0 commit comments

Comments
 (0)