Skip to content

Commit 3572ab9

Browse files
committed
Pretty print for static IR, define get_ir on static functions directly.
1 parent bda1ef7 commit 3572ab9

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

src/dynamic/dynamic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
1919
accepts_output_grad::Bool
2020
end
2121

22+
Base.nameof(gen_fn::DynamicDSLFunction) =
23+
nameof(gen_fn.julia_function)
24+
2225
function DynamicDSLFunction(arg_types::Vector{Type},
2326
arg_defaults::Vector{Union{Some{Any},Nothing}},
2427
julia_function::Function,

src/static_ir/print_ir.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
Base.show(io::IO, ::MIME"text/plain", ir::StaticIR) =
2+
print_ir(io, ir)
3+
4+
function print_ir(io::IO, ir::StaticIR)
5+
println(io, "== Static IR ==")
6+
args = join((string(arg.name) for arg in ir.arg_nodes), ", ")
7+
println(io, "Arguments: ($args)")
8+
for node in ir.nodes
9+
node in ir.arg_nodes && continue
10+
print(io, " "); print_ir(io, node); println(io)
11+
end
12+
print(io, " return $(ir.return_node.name)")
13+
end
14+
15+
function print_ir(io::IO, node::TrainableParameterNode)
16+
print(io, "@param $(node.name)::$(node.typ)")
17+
end
18+
19+
function print_ir(io::IO, node::JuliaNode)
20+
inputs = join((string(i.name) for i in node.inputs), ", ")
21+
print(io, "$(node.name) = $(node.fn)($inputs)")
22+
end
23+
24+
function print_ir(io::IO, node::GenerativeFunctionCallNode)
25+
inputs = join((string(i.name) for i in node.inputs), ", ")
26+
gen_fn_name = nameof(node.generative_function)
27+
print(io, "$(node.name) = @trace($(gen_fn_name)($inputs), :$(node.addr))")
28+
end
29+
30+
function print_ir(io::IO, node::RandomChoiceNode)
31+
inputs = join((string(i.name) for i in node.inputs), ", ")
32+
print(io, "$(node.name) = @trace($(node.dist)($inputs), :$(node.addr))")
33+
end

src/static_ir/static_ir.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ Most generative function interface methods are generated from the intermediate r
2727
"""
2828
abstract type StaticIRGenerativeFunction{T,U} <: GenerativeFunction{T,U} end
2929

30+
Base.nameof(gen_fn::StaticIRGenerativeFunction) =
31+
nameof(type_of(gen_fn))
32+
3033
function get_ir end
3134
function get_gen_fn_type end
3235
function get_options end
@@ -52,6 +55,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
5255
params::Dict{Symbol,Any}
5356
end
5457
(gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3]
58+
$(GlobalRef(Gen, :get_ir))(::$gen_fn_type_name) = $(QuoteNode(ir))
5559
$(GlobalRef(Gen, :get_ir))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir))
5660
$(GlobalRef(Gen, :get_trace_type))(::Type{$gen_fn_type_name}) = $trace_struct_name
5761
$(GlobalRef(Gen, :has_argument_grads))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads))
@@ -63,6 +67,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
6367
Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}())))
6468
end
6569

70+
include("print_ir.jl")
6671
include("render_ir.jl")
6772

6873
###########################

0 commit comments

Comments
 (0)