File tree Expand file tree Collapse file tree 1 file changed +15
-4
lines changed Expand file tree Collapse file tree 1 file changed +15
-4
lines changed Original file line number Diff line number Diff line change 4545
4646accepts_output_grad (gen_fn:: DynamicDSLFunction ) = gen_fn. accepts_output_grad
4747
48+ mutable struct GFUntracedState
49+ params:: Dict{Symbol,Any}
50+ end
51+
4852function (gen_fn:: DynamicDSLFunction )(args... )
49- (_, _, retval) = propose (gen_fn, args )
50- retval
53+ state = GFUntracedState (gen_fn. params )
54+ gen_fn . julia_function (state, args ... )
5155end
5256
53- function exec (gf :: DynamicDSLFunction , state, args:: Tuple )
54- gf . julia_function (state, args... )
57+ function exec (gen_fn :: DynamicDSLFunction , state, args:: Tuple )
58+ gen_fn . julia_function (state, args... )
5559end
5660
5761# whether there is a gradient of score with respect to each argument
@@ -76,6 +80,13 @@ function dynamic_trace_impl(expr::Expr)
7680 end
7781end
7882
83+ # Defaults for untraced execution
84+ @inline traceat (state:: GFUntracedState , gen_fn:: GenerativeFunction , args, key) =
85+ gen_fn (args... )
86+
87+ @inline traceat (state:: GFUntracedState , dist:: Distribution , args, key) =
88+ random (dist, args... )
89+
7990# #######################
8091# trainable parameters #
8192# #######################
You can’t perform that action at this time.
0 commit comments