Skip to content

Commit 7a458b8

Browse files
authored
Merge pull request #382 from probcomp/20200216_ztangent_irprettyprint
Pretty print for static IR
2 parents bda1ef7 + 23c63ba commit 7a458b8

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

src/static_ir/print_ir.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
Base.show(io::IO, ::MIME"text/plain", ir::StaticIR) =
2+
print_ir(io, ir)
3+
4+
function print_ir(io::IO, ir::StaticIR)
5+
println(io, "== Static IR ==")
6+
args = join((string(arg.name) for arg in ir.arg_nodes), ", ")
7+
println(io, "Arguments: ($args)")
8+
for node in ir.nodes
9+
node in ir.arg_nodes && continue
10+
print(io, " "); print_ir(io, node); println(io)
11+
end
12+
print(io, " return $(ir.return_node.name)")
13+
end
14+
15+
function print_ir(io::IO, node::TrainableParameterNode)
16+
print(io, "@param $(node.name)::$(node.typ)")
17+
end
18+
19+
function print_ir(io::IO, node::JuliaNode)
20+
inputs = join((string(i.name) for i in node.inputs), ", ")
21+
print(io, "$(node.name) = $(node.fn)($inputs)")
22+
end
23+
24+
function print_ir(io::IO, node::GenerativeFunctionCallNode)
25+
inputs = join((string(i.name) for i in node.inputs), ", ")
26+
gen_fn_name = ir_name(node.generative_function)
27+
print(io, "$(node.name) = @trace($(gen_fn_name)($inputs), :$(node.addr))")
28+
end
29+
30+
function print_ir(io::IO, node::RandomChoiceNode)
31+
inputs = join((string(i.name) for i in node.inputs), ", ")
32+
print(io, "$(node.name) = @trace($(node.dist)($inputs), :$(node.addr))")
33+
end
34+
35+
ir_name(fn::GenerativeFunction) = nameof(typeof(fn))
36+
ir_name(fn::DynamicDSLFunction) = nameof(fn.julia_function)

src/static_ir/static_ir.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
5252
params::Dict{Symbol,Any}
5353
end
5454
(gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3]
55+
$(GlobalRef(Gen, :get_ir))(::$gen_fn_type_name) = $(QuoteNode(ir))
5556
$(GlobalRef(Gen, :get_ir))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir))
5657
$(GlobalRef(Gen, :get_trace_type))(::Type{$gen_fn_type_name}) = $trace_struct_name
5758
$(GlobalRef(Gen, :has_argument_grads))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads))
@@ -63,6 +64,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
6364
Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}())))
6465
end
6566

67+
include("print_ir.jl")
6668
include("render_ir.jl")
6769

6870
###########################

test/static_ir/static_ir.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ z = add_julia_node!(builder, (u, v) -> u + v, inputs=[u, v], name=:z)
2222
set_return_node!(builder, z)
2323
ir = build_ir(builder)
2424
bar = eval(generate_generative_function(ir, :bar, track_diffs=false, cache_julia_nodes=false))
25+
@test occursin("== Static IR ==", repr("text/plain", ir))
2526

2627
#@gen (static, nojuliacache) function foo(a, b)
2728
#@param theta::Float64
@@ -45,6 +46,7 @@ w = add_julia_node!(builder, (z, a, theta) -> z + 1 + a + theta, inputs=[z, a, t
4546
set_return_node!(builder, w)
4647
ir = build_ir(builder)
4748
foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia_nodes=false))
49+
@test occursin("== Static IR ==", repr("text/plain", ir))
4850

4951
theta_val = rand()
5052
set_param!(foo, :theta, theta_val)
@@ -58,6 +60,7 @@ one = add_constant_node!(builder, 2)
5860
set_return_node!(builder, one)
5961
ir = build_ir(builder)
6062
const_fn = eval(generate_generative_function(ir, :const_fn, track_diffs=false, cache_julia_nodes=false))
63+
@test occursin("== Static IR ==", repr("text/plain", ir))
6164

6265
Gen.load_generated_functions()
6366

0 commit comments

Comments
 (0)