
Commit d6f636c

fixes

1 parent ad05f34

13 files changed: +64 −156 lines

src/dynamic/assess.jl

Lines changed: 21 additions & 16 deletions
@@ -2,11 +2,22 @@ mutable struct GFAssessState
     choices::ChoiceMap
     weight::Float64
     visitor::AddressVisitor
-    params::Dict{Symbol,Any}
+    active_gen_fn::DynamicDSLFunction # mutated by splicing
+    parameter_context::Dict
+
+    function GFAssessState(gen_fn, choices, parameter_context)
+        new(choices, 0.0, AddressVisitor(), gen_fn, parameter_context)
+    end
 end
 
-function GFAssessState(choices, params::Dict{Symbol,Any})
-    GFAssessState(choices, 0., AddressVisitor(), params)
+get_parameter_store(state::GFAssessState) = get_julia_store(state.parameter_context)
+
+get_parameter_id(state::GFAssessState, name::Symbol) = (state.active_gen_fn, name)
+
+get_active_gen_fn(state::GFAssessState) = state.active_gen_fn
+
+function set_active_gen_fn!(state::GFAssessState, gen_fn::GenerativeFunction)
+    state.active_gen_fn = gen_fn
 end
 
 function traceat(state::GFAssessState, dist::Distribution{T},
@@ -22,7 +33,7 @@ function traceat(state::GFAssessState, dist::Distribution{T},
     # update weight
     state.weight += logpdf(dist, retval, args...)
 
-    retval
+    return retval
 end
 
 function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U},
@@ -41,25 +52,19 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U},
     # update score
     state.weight += weight
 
-    retval
-end
-
-function splice(state::GFAssessState, gen_fn::DynamicDSLFunction, args::Tuple)
-    prev_params = state.params
-    state.params = gen_fn.params
-    retval = exec(gen_fn, state, args)
-    state.params = prev_params
-    retval
+    return retval
 end
 
-function assess(gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap)
-    state = GFAssessState(choices, gen_fn.params)
+function assess(
+        gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap;
+        parameter_context=default_parameter_context)
+    state = GFAssessState(gen_fn, choices, parameter_context)
     retval = exec(gen_fn, state, args)
 
     unvisited = get_unvisited(get_visited(state.visitor), choices)
     if !isempty(unvisited)
         error("Assess did not visit the following constraint addresses:\n$unvisited")
     end
 
-    (state.weight, retval)
+    return (state.weight, retval)
 end
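
A minimal usage sketch of the updated assess signature (the @gen model foo below is hypothetical and not part of this commit; parameter_context is optional and defaults to default_parameter_context, as shown in the diff):

    # Hypothetical model for illustration only.
    @gen function foo(x)
        y ~ normal(x, 1.0)
        return y
    end

    choices = choicemap((:y, 1.2))

    # assess now accepts parameter_context as a keyword argument.
    (weight, retval) = assess(foo, (0.0,), choices;
                              parameter_context=default_parameter_context)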

src/dynamic/dynamic.jl

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,10 @@ function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
     return "Gen DML generative function: $(gen_fn.julia_function)"
 end
 
+function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
+    # TODO for this, we need to walk the code... (and throw errors when the
+end
+
 function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction}
     # pad args with default values, if available
     if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults)

src/dynamic/propose.jl

Lines changed: 20 additions & 14 deletions
@@ -2,11 +2,23 @@ mutable struct GFProposeState
     choices::DynamicChoiceMap
     weight::Float64
     visitor::AddressVisitor
-    params::Dict{Symbol,Any}
+    active_gen_fn::DynamicDSLFunction # mutated by splicing
+    parameter_context::Dict
+
+    function GFProposeState(
+            gen_fn::GenerativeFunction, parameter_context)
+        return new(choicemap(), 0.0, AddressVisitor(), gen_fn, parameter_context)
+    end
 end
 
-function GFProposeState(params::Dict{Symbol,Any})
-    GFProposeState(choicemap(), 0., AddressVisitor(), params)
+get_parameter_store(state::GFProposeState) = get_julia_store(state.parameter_context)
+
+get_parameter_id(state::GFProposeState, name::Symbol) = (state.active_gen_fn, name)
+
+get_active_gen_fn(state::GFProposeState) = state.active_gen_fn
+
+function set_active_gen_fn!(state::GFProposeState, gen_fn::GenerativeFunction)
+    state.active_gen_fn = gen_fn
 end
 
 function traceat(state::GFProposeState, dist::Distribution{T},
@@ -47,16 +59,10 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
     retval
 end
 
-function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple)
-    prev_params = state.params
-    state.params = gen_fn.params
-    retval = exec(gen_fn, state, args)
-    state.params = prev_params
-    retval
-end
-
-function propose(gen_fn::DynamicDSLFunction, args::Tuple)
-    state = GFProposeState(gen_fn.params)
+function propose(
+        gen_fn::DynamicDSLFunction, args::Tuple;
+        parameter_context=default_parameter_context)
+    state = GFProposeState(gen_fn, parameter_context)
     retval = exec(gen_fn, state, args)
-    (state.choices, state.weight, retval)
+    return (state.choices, state.weight, retval)
 end
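
Likewise for propose, a minimal sketch under the same assumptions (reusing the hypothetical foo model from the assess sketch above):

    # propose now also threads a parameter_context keyword through execution.
    (choices, weight, retval) = propose(foo, (0.0,);
                                        parameter_context=default_parameter_context)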

src/dynamic/simulate.jl

Lines changed: 1 addition & 1 deletion
@@ -66,5 +66,5 @@ function simulate(
     state = GFSimulateState(gen_fn, args, parameter_context)
     retval = exec(gen_fn, state, args)
     set_retval!(state.trace, retval)
-    state.trace
+    return state.trace
 end

src/optimization.jl

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ export FixedStepGradientDescent
 export DecayStepGradientDescent
 export init_optimizer
 export apply_update!
+export CompositeOptimizer
 
 export JuliaParameterStore
 export init_parameter!
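
For reference, the newly exported CompositeOptimizer is exercised in test/inference/train.jl below; a hedged sketch of its use (gen_fn stands in for any generative function with trainable parameters):

    # Pair an update rule with all trainable parameters of a generative
    # function, then apply one update step.
    optimizer = CompositeOptimizer(DecayStepGradientDescent(0.01, 1000000), gen_fn)
    apply_update!(optimizer)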

src/static_ir/backprop.jl

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
         error("Distribution $dist does not logpdf gradient for its output value")
     end
     push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-        $(gradient_var(node)), retval_grad)))
+        $(gradient_var(node)), $logpdf_grad[1])))
 end
 end

src/static_ir/static_ir.jl

Lines changed: 3 additions & 1 deletion
@@ -46,6 +46,8 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
     has_argument_grads = tuple(map((node) -> node.compute_grad, ir.arg_nodes)...)
     accepts_output_grad = ir.accepts_output_grad
 
+    show_str = "Gen SML generative function: $name"
+
     gen_fn_defn = quote
         struct $gen_fn_type_name <: $(QuoteNode(StaticIRGenerativeFunction)){$return_type,$trace_type}
         end
@@ -62,7 +64,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
             return $(GlobalRef(Gen, :get_parameters))($(QuoteNode(ir)), gen_fn, context)
         end
         function Base.show(io::IO, ::MIME"text/plain", gen_fn::$gen_fn_type_name)
-            return "Gen SML generative function: $name)"
+            return $(QuoteNode(show_str))
         end
 
     end
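
The fix computes the display string once, outside the generated quote block, and splices it in with QuoteNode (also dropping the stray closing parenthesis from the old string). A toy illustration of the pattern, with hypothetical names:

    # Interpolating a precomputed value into a quoted expression via QuoteNode.
    name = :my_gen_fn
    show_str = "Gen SML generative function: $name"
    ex = quote
        description() = $(QuoteNode(show_str))  # splices the literal string
    end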

test/dsl/dynamic_dsl.jl

Lines changed: 0 additions & 2 deletions
@@ -398,8 +398,6 @@ end
     conf = FixedStepGradientDescent(0.001)
     optimizer = init_optimizer(conf, [(foo, :theta)])
     apply_update!(optimizer)
-    println(get_parameter_value((foo, :theta)))
-    println(get_gradient((foo, :theta)))
     @test isapprox(get_parameter_value((foo, :theta)), 0.001)
     @test isapprox(get_gradient((foo, :theta)), 0.0)
 end

test/inference/train.jl

Lines changed: 5 additions & 4 deletions
@@ -85,15 +85,16 @@
     value = get_parameter_value((student, name))
 
     # evaluate total log density at value + dx
-    set_param!(student, name, value + dx)
+    init_parameter!((student, name), value + dx)
+
     lpdf_pos = 0.
     for i=1:minibatch_size
         (incr, _) = assess(student, inputs[i], constraints[i])
         lpdf_pos += incr
     end
 
     # evaluate total log density at value - dx
-    set_param!(student, name, value - dx)
+    init_parameter!((student, name), value - dx)
     lpdf_neg = 0.
     for i=1:minibatch_size
         (incr, _) = assess(student, inputs[i], constraints[i])
@@ -103,11 +104,11 @@
     expected = (lpdf_pos - lpdf_neg) / (2 * dx)
     @test isapprox(actual, expected, atol=1e-4)
 
-    set_param!(student, name, value)
+    init_parameter!((student, name), value)
 end
 
 # use stochastic gradient descent
-optimizer = CompositeOptimizer(GradientDescent(0.01, 1000000), student)
+optimizer = CompositeOptimizer(DecayStepGradientDescent(0.01, 1000000), student)
 train!(student, data_generator, optimizer,
     num_epoch=2000, epoch_size=50, num_minibatch=1, minibatch_size=50,
     verbose=false)
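
The test now addresses parameters by (gen_fn, name) tuples via init_parameter! instead of set_param!; a minimal round trip under those assumptions (:theta is illustrative):

    # Set a parameter value by its (generative function, name) identifier,
    # then read it back.
    init_parameter!((student, :theta), 0.0)
    value = get_parameter_value((student, :theta))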

test/modeling_library/map.jl

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     return z
 end
 
-set_param!(foo, :std, 1.)
+init_parameter!((foo, :std), 1.0)
 
 bar = Map(foo)
 xs = [1.0, 2.0, 3.0, 4.0]
