Skip to content

Commit 94b9e3a

Browse files
committed
reorganize
1 parent 4fe55a6 commit 94b9e3a

File tree

7 files changed

+542
-526
lines changed

7 files changed

+542
-526
lines changed

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ include("Tracing.jl")
189189
include("Compiler.jl")
190190

191191
include("Overlay.jl")
192-
include("ProbProg.jl")
192+
include("probprog/ProbProg.jl")
193193

194194
# Serialization
195195
include("serialization/Serialization.jl")

src/probprog/Display.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
2+
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
3+
VERT = '\u2502'
4+
PLUS = '\u251C'
5+
HORZ = '\u2500'
6+
LAST = '\u2514'
7+
8+
indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
9+
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
10+
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])
11+
12+
for i in vert_bars
13+
indent_vert[i] = VERT
14+
indent[i] = VERT
15+
indent_last[i] = VERT
16+
end
17+
18+
indent_vert_str = join(indent_vert)
19+
indent_str = join(indent)
20+
indent_last_str = join(indent_last)
21+
22+
sorted_choices = sort(collect(trace.choices); by=x -> x[1])
23+
n = length(sorted_choices)
24+
25+
if trace.retval !== nothing
26+
n += 1
27+
end
28+
29+
if trace.weight !== nothing
30+
n += 1
31+
end
32+
33+
cur = 1
34+
35+
if trace.retval !== nothing
36+
print(io, indent_vert_str)
37+
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
38+
cur += 1
39+
end
40+
41+
if trace.weight !== nothing
42+
print(io, indent_vert_str)
43+
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
44+
cur += 1
45+
end
46+
47+
for (key, value) in sorted_choices
48+
print(io, indent_vert_str)
49+
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
50+
cur += 1
51+
end
52+
53+
sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
54+
n += length(sorted_subtraces)
55+
56+
for (key, subtrace) in sorted_subtraces
57+
print(io, indent_vert_str)
58+
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
59+
_show_pretty(
60+
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
61+
)
62+
cur += 1
63+
end
64+
end
65+
66+
function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
67+
println(io, "ProbProgTrace:")
68+
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
69+
println(io, " (empty)")
70+
else
71+
_show_pretty(io, trace, 0, ())
72+
end
73+
end
74+
75+
function Base.show(io::IO, trace::ProbProgTrace)
76+
if get(io, :compact, false)
77+
choices_count = length(trace.choices)
78+
has_retval = trace.retval !== nothing
79+
print(io, "ProbProgTrace($(choices_count) choices")
80+
if has_retval
81+
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
82+
end
83+
print(io, ")")
84+
else
85+
show(io, MIME"text/plain"(), trace)
86+
end
87+
end

0 commit comments

Comments
 (0)