Skip to content

Commit a96a779

Browse files
authored
Merge pull request #322 from ztangent/remove-untraced-overhead
Remove overhead during untraced execution of dynamic @gen functions.
2 parents dec8c53 + 37aa5e4 commit a96a779

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

src/dynamic/dynamic.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,17 @@ end
4545

4646
accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad
4747

48+
mutable struct GFUntracedState
49+
params::Dict{Symbol,Any}
50+
end
51+
4852
function (gen_fn::DynamicDSLFunction)(args...)
49-
(_, _, retval) = propose(gen_fn, args)
50-
retval
53+
state = GFUntracedState(gen_fn.params)
54+
gen_fn.julia_function(state, args...)
5155
end
5256

53-
function exec(gf::DynamicDSLFunction, state, args::Tuple)
54-
gf.julia_function(state, args...)
57+
function exec(gen_fn::DynamicDSLFunction, state, args::Tuple)
58+
gen_fn.julia_function(state, args...)
5559
end
5660

5761
# whether there is a gradient of score with respect to each argument
@@ -76,6 +80,13 @@ function dynamic_trace_impl(expr::Expr)
7680
end
7781
end
7882

83+
# Defaults for untraced execution
84+
@inline traceat(state::GFUntracedState, gen_fn::GenerativeFunction, args, key) =
85+
gen_fn(args...)
86+
87+
@inline traceat(state::GFUntracedState, dist::Distribution, args, key) =
88+
random(dist, args...)
89+
7990
########################
8091
# trainable parameters #
8192
########################

0 commit comments

Comments
 (0)