Skip to content

Commit e647b0d

Browse files
committed
improve
1 parent 2b81db9 commit e647b0d

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

src/ProbProg.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,11 @@ select(syms::Symbol...) = Set(syms)
966966
choicemap() = Constraint()
967967
const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any}
968968

969+
function with_compiled_cache(f)
970+
cache = CompiledFnCache()
971+
return f(cache)
972+
end
973+
969974
function metropolis_hastings(
970975
trace::ProbProgTrace,
971976
sel::Selection;

test/probprog/linear_regression.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,17 @@ function my_inference_program(xs, ys, num_iters)
3838

3939
trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint)
4040

41-
compiled_cache = ProbProg.CompiledFnCache()
42-
43-
for i in 1:num_iters
44-
trace, _ = ProbProg.metropolis_hastings(
45-
trace, ProbProg.select(:slope); compiled_cache
46-
)
47-
trace, _ = ProbProg.metropolis_hastings(
48-
trace, ProbProg.select(:intercept); compiled_cache
49-
)
41+
trace = ProbProg.with_compiled_cache() do cache
42+
local t = trace
43+
for _ in 1:num_iters
44+
t, _ = ProbProg.metropolis_hastings(
45+
t, ProbProg.select(:slope); compiled_cache=cache
46+
)
47+
t, _ = ProbProg.metropolis_hastings(
48+
t, ProbProg.select(:intercept); compiled_cache=cache
49+
)
50+
end
51+
return t
5052
end
5153

5254
choices = ProbProg.get_choices(trace)

0 commit comments

Comments
 (0)