|
| 1 | +Base.show(io::IO, ::MIME"text/plain", ir::StaticIR) = |
| 2 | + print_ir(io, ir) |
| 3 | + |
| 4 | +function print_ir(io::IO, ir::StaticIR) |
| 5 | + println(io, "== Static IR ==") |
| 6 | + args = join((string(arg.name) for arg in ir.arg_nodes), ", ") |
| 7 | + println(io, "Arguments: ($args)") |
| 8 | + for node in ir.nodes |
| 9 | + node in ir.arg_nodes && continue |
| 10 | + print(io, " "); print_ir(io, node); println(io) |
| 11 | + end |
| 12 | + print(io, " return $(ir.return_node.name)") |
| 13 | +end |
| 14 | + |
| 15 | +function print_ir(io::IO, node::TrainableParameterNode) |
| 16 | + print(io, "@param $(node.name)::$(node.typ)") |
| 17 | +end |
| 18 | + |
| 19 | +function print_ir(io::IO, node::JuliaNode) |
| 20 | + inputs = join((string(i.name) for i in node.inputs), ", ") |
| 21 | + print(io, "$(node.name) = $(node.fn)($inputs)") |
| 22 | +end |
| 23 | + |
| 24 | +function print_ir(io::IO, node::GenerativeFunctionCallNode) |
| 25 | + inputs = join((string(i.name) for i in node.inputs), ", ") |
| 26 | + gen_fn_name = nameof(node.generative_function) |
| 27 | + print(io, "$(node.name) = @trace($(gen_fn_name)($inputs), :$(node.addr))") |
| 28 | +end |
| 29 | + |
| 30 | +function print_ir(io::IO, node::RandomChoiceNode) |
| 31 | + inputs = join((string(i.name) for i in node.inputs), ", ") |
| 32 | + print(io, "$(node.name) = @trace($(node.dist)($inputs), :$(node.addr))") |
| 33 | +end |
0 commit comments