Skip to content

Commit 561b051

Browse files
committed
better print
1 parent 91a0850 commit 561b051

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

src/ProbProg.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,5 +291,71 @@ function simulate_internal(
291291
return result
292292
end
293293

294+
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
295+
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
296+
VERT = '\u2502'
297+
PLUS = '\u251C'
298+
HORZ = '\u2500'
299+
LAST = '\u2514'
300+
301+
indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
302+
indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
303+
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
304+
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])
305+
306+
for i in vert_bars
307+
indent_vert[i] = VERT
308+
indent[i] = VERT
309+
indent_last[i] = VERT
310+
end
311+
312+
indent_vert_str = join(indent_vert)
313+
indent_str = join(indent)
314+
indent_last_str = join(indent_last)
315+
316+
sorted_choices = sort(collect(trace.choices); by=x -> x[1])
317+
n = length(sorted_choices)
318+
319+
if trace.retval !== nothing
320+
n += 1
321+
end
322+
323+
cur = 1
324+
325+
if trace.retval !== nothing
326+
print(io, indent_vert_str)
327+
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
328+
cur += 1
329+
end
330+
331+
for (key, value) in sorted_choices
332+
print(io, indent_vert_str)
333+
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
334+
cur += 1
335+
end
336+
end
337+
338+
function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
339+
println(io, "ProbProgTrace:")
340+
if isempty(trace.choices) && trace.retval === nothing
341+
println(io, " (empty)")
342+
else
343+
_show_pretty(io, trace, 0, ())
344+
end
345+
end
346+
347+
function Base.show(io::IO, trace::ProbProgTrace)
348+
if get(io, :compact, false)
349+
choices_count = length(trace.choices)
350+
has_retval = trace.retval !== nothing
351+
print(io, "ProbProgTrace($(choices_count) choices")
352+
if has_retval
353+
print(io, ", retval=$(trace.retval)")
354+
end
355+
print(io, ")")
356+
else
357+
show(io, MIME"text/plain"(), trace)
358+
end
359+
end
294360

295361
end

0 commit comments

Comments
 (0)