@@ -2,11 +2,22 @@ mutable struct GFAssessState
22 choices:: ChoiceMap
33 weight:: Float64
44 visitor:: AddressVisitor
5- params:: Dict{Symbol,Any}
5+ active_gen_fn:: DynamicDSLFunction # mutated by splicing
6+ parameter_context:: Dict
7+
8+ function GFAssessState (gen_fn, choices, parameter_context)
9+ new (choices, 0.0 , AddressVisitor (), gen_fn, parameter_context)
10+ end
611end
712
8- function GFAssessState (choices, params:: Dict{Symbol,Any} )
9- GFAssessState (choices, 0. , AddressVisitor (), params)
13+ get_parameter_store (state:: GFAssessState ) = get_julia_store (state. parameter_context)
14+
15+ get_parameter_id (state:: GFAssessState , name:: Symbol ) = (state. active_gen_fn, name)
16+
17+ get_active_gen_fn (state:: GFAssessState ) = state. active_gen_fn
18+
19+ function set_active_gen_fn! (state:: GFAssessState , gen_fn:: GenerativeFunction )
20+ state. active_gen_fn = gen_fn
1021end
1122
1223function traceat (state:: GFAssessState , dist:: Distribution{T} ,
@@ -22,7 +33,7 @@ function traceat(state::GFAssessState, dist::Distribution{T},
2233 # update weight
2334 state. weight += logpdf (dist, retval, args... )
2435
25- retval
36+ return retval
2637end
2738
2839function traceat (state:: GFAssessState , gen_fn:: GenerativeFunction{T,U} ,
@@ -41,25 +52,19 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U},
4152 # update score
4253 state. weight += weight
4354
44- retval
45- end
46-
47- function splice (state:: GFAssessState , gen_fn:: DynamicDSLFunction , args:: Tuple )
48- prev_params = state. params
49- state. params = gen_fn. params
50- retval = exec (gen_fn, state, args)
51- state. params = prev_params
52- retval
55+ return retval
5356end
5457
55- function assess (gen_fn:: DynamicDSLFunction , args:: Tuple , choices:: ChoiceMap )
56- state = GFAssessState (choices, gen_fn. params)
58+ function assess (
59+ gen_fn:: DynamicDSLFunction , args:: Tuple , choices:: ChoiceMap ;
60+ parameter_context= default_parameter_context)
61+ state = GFAssessState (gen_fn, choices, parameter_context)
5762 retval = exec (gen_fn, state, args)
5863
5964 unvisited = get_unvisited (get_visited (state. visitor), choices)
6065 if ! isempty (unvisited)
6166 error (" Assess did not visit the following constraint addresses:\n $unvisited " )
6267 end
6368
64- (state. weight, retval)
69+ return (state. weight, retval)
6570end
0 commit comments