Commit 4b5d804

checkpoint on SML changes
1 parent 6e85f3f commit 4b5d804

6 files changed: +122 additions, -107 deletions

src/optimization.jl

Lines changed: 11 additions & 5 deletions
@@ -293,21 +293,27 @@ function increment_gradient!(
     return nothing
 end
 
-# TODO docstring (thread-safe)
-function increment_gradient!(
-        store::JuliaParameterStore, id::JuliaParameterID,
-        increment)
+function get_gradient_accumulator(store::JuliaParameterStore, id::JuliaParameterID)
     (gen_fn, name) = id
     try
-        in_place_add!(store.gradient_accumulators[gen_fn][name], increment)
+        return store.gradient_accumulators[gen_fn][name]
     catch KeyError
         @error "parameter not initialized: $id"
         rethrow()
     end
+end
+
+# TODO docstring (thread-safe)
+function increment_gradient!(
+        store::JuliaParameterStore, id::JuliaParameterID,
+        increment)
+    accumulator = get_gradient_accumulator(store, id)
+    in_place_add!(accumulator, increment)
     return nothing
 end
 
 # TODO docstring (not thread-safe)
+
 function get_parameter_value(store::JuliaParameterStore, id::JuliaParameterID)
     (gen_fn, name) = id
     try
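The refactor above separates accumulator lookup (get_gradient_accumulator) from the mutation (in_place_add!), which is what makes the increment thread-safe: the accumulator object can guard its own updates. Below is a minimal standalone sketch of that pattern, assuming a hypothetical LockedAccumulator type and a plain Dict store rather than Gen's actual JuliaParameterStore:

# Hypothetical sketch of a lock-guarded accumulator (not Gen's implementation).
mutable struct LockedAccumulator{T}
    value::T
    lock::ReentrantLock
end

LockedAccumulator(value) = LockedAccumulator(value, ReentrantLock())

# Thread-safe in-place addition; the lock lives with the accumulator,
# so callers only need a reference to it.
function in_place_add!(acc::LockedAccumulator, increment)
    lock(acc.lock) do
        acc.value .+= increment
    end
    return acc
end

# Toy store: parameter name => accumulator (stand-in for gradient_accumulators).
store = Dict(:theta => LockedAccumulator(zeros(3)))

get_gradient_accumulator(store, name) = store[name]  # lookup only, no mutation

# Many tasks can safely increment the same parameter's gradient.
Threads.@threads for _ in 1:1000
    acc = get_gradient_accumulator(store, :theta)
    in_place_add!(acc, [0.001, 0.001, 0.001])
end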

src/static_ir/backprop.jl

Lines changed: 66 additions & 65 deletions
@@ -21,30 +21,32 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix)
 const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg")
 maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i")
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode)
-    # TODO: only need to mark it if we are doing backprop params
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropParamsMode)
     push!(fwd_marked, node)
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropTraceMode)
+end
+
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode, mode)
     if node.compute_grad
         push!(fwd_marked, node)
     end
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode, mode)
     if any(input_node in fwd_marked for input_node in node.inputs)
         push!(fwd_marked, node)
     end
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode, mode)
     if node in selected_choices
         push!(fwd_marked, node)
     end
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode, mode)
     if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs)
         push!(fwd_marked, node)
     end
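This hunk threads a mode argument through the forward marking pass so that trainable parameters are treated as gradient sources only when accumulating parameter gradients. A minimal standalone sketch of that dispatch pattern, using toy node types rather than Gen's static IR:

# Toy illustration of dispatching a marking pass on a mode type.
struct BackpropTraceMode end
struct BackpropParamsMode end

struct ParamNode
    name::Symbol
end

# Parameters are gradient sources only in params mode ...
mark!(marked, node::ParamNode, ::BackpropParamsMode) = push!(marked, node)
# ... and are ignored in trace mode.
mark!(marked, node::ParamNode, ::BackpropTraceMode) = marked

marked = Set{ParamNode}()
mark!(marked, ParamNode(:theta), BackpropParamsMode())   # marked now contains :theta
mark!(marked, ParamNode(:theta), BackpropTraceMode())    # no-op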
@@ -73,6 +75,9 @@ end
 
 function back_pass!(back_marked, node::GenerativeFunctionCallNode)
     # the logpdf of every generative function call is a SINK
+    # (we could ask whether the generative function is deterministic or not
+    # as a performance optimization, because only stochastic generative functions
+    # actually have a non-trivial logpdf)
     for (input_node, has_grad) in zip(node.inputs, has_argument_grads(node.generative_function))
         if has_grad
             push!(back_marked, input_node)
@@ -82,17 +87,15 @@ end
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode)
     if node in back_marked
-        push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace),
-            $(QuoteNode(node.name)))))
+        push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name)))))
     end
 
     if node in fwd_marked && node in back_marked
 
         # initialize gradient to zero
         # NOTE: we are avoiding allocating a new gradient accumulator for this function
         # instead, we are using the threadsafe gradient accumulator directly..
-        #push!(stmts, :($(gradient_var(node)) = zero($(node.name))))
-        push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gen_fn))(trace).params_grad[$(QuoteNode(node.name))]))
+        push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gradient_accumulator))(trace, $(QuoteNode(node.name)))))
     end
 end
 
@@ -106,7 +109,7 @@ end
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
 
-    if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs)
+    if (node in fwd_marked) && (node in back_marked)
 
         # tracked forward execution
         tape = tape_var(node)
@@ -128,27 +131,20 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
 
         # initialize gradient to zero
         push!(stmts, :($(gradient_var(node)) = zero($(node.name))))
-    else
 
-        # regular forward execution.
+    elseif node in back_marked
 
-        # we need the value for initializing gradient to zero (to get the type
-        # and e.g. shape), and for reference by other nodes during
-        # back_codegen! we could be more selective about which JuliaNodes need
-        # to be evalutaed, that is a performance optimization for the future
+        # regular forward execution.
         args = map((input_node) -> input_node.name, node.inputs)
         push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...))))
     end
 end
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode)
-    # for reference by other nodes during back_codegen!
-    # could performance optimize this away
-    push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node))))
-
     # every random choice is in back_marked, since it affects its logpdf, but
     # also possibly due to other downstream usage of the value
     @assert node in back_marked
+    push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node))))
 
     if node in fwd_marked
         # the only way we are fwd_marked is if this choice was selected
@@ -160,14 +156,16 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode)
 end
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode)
-    # for reference by other nodes during back_codegen!
-    # could performance optimize this away
-    subtrace_fieldname = get_subtrace_fieldname(node)
-    push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname)))
+
+    if node in back_marked
+        # for reference by other nodes during back_codegen!
+        subtrace_fieldname = get_subtrace_fieldname(node)
+        push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname)))
+    end
 
     # NOTE: we will still potentially run choice_gradients recursively on the generative function,
     # we just might not use its return value gradient.
-    if node in fwd_marked && node in back_marked
+    if (node in fwd_marked) && (node in back_marked)
         # we are fwd_marked if an input was fwd_marked, or if we were selected internally
         push!(stmts, :($(gradient_var(node)) = zero($(node.name))))
     end
@@ -181,15 +179,6 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
     push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
     push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
         $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += retval_grad))
-    end
-
-    if node in fwd_marked && node in back_marked
-        #NOTE: unecessary, because we accumulated in-place already
-        #push!(stmts, :($(QuoteNode(increment_param_grad!))(trace.$static_ir_gen_fn_ref,
-        #$(QuoteNode(node.name)),
-        #$(gradient_var(node)),
-        #scale_factor)))
     end
 end
 
@@ -199,8 +188,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
 end
 
@@ -209,8 +197,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
     if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs)
 
@@ -222,9 +209,13 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
     for (i, input_node) in enumerate(node.inputs)
         if input_node in fwd_marked
             arg_maybe_tracked = maybe_tracked_arg_var(node, i)
-            push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor)))
-            #push!(stmts, :($(gradient_var(input_node)) += $(QuoteNode(deriv))($arg_maybe_tracked)))
+            if isa(input_node, TrainableParameterNode)
+                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                    $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor)))
+            else
+                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                    $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked))))
+            end
         end
     end
 end
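Throughout this file, gradients flowing into trainable parameters are accumulated with the scale_factor, while gradients for intermediate nodes are accumulated without it. The two-versus-three-argument calls emitted by the generated code could behave roughly like the following standalone sketch (a hypothetical array-based in_place_add!, not Gen's generic definition):

# Sketch of the two accumulation arities used by the generated code.
# Three-argument form: scale the increment (used for trainable parameter gradients).
function in_place_add!(acc::Vector{Float64}, increment, scale_factor::Real)
    acc .+= scale_factor .* increment
    return acc
end

# Two-argument form: plain accumulation (used for intermediate node gradients).
function in_place_add!(acc::Vector{Float64}, increment)
    acc .+= increment
    return acc
end

param_grad = zeros(2)
node_grad = zeros(2)
in_place_add!(param_grad, [1.0, 2.0], 0.5)   # -> [0.5, 1.0]
in_place_add!(node_grad, [1.0, 2.0])         # -> [1.0, 2.0]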
@@ -246,9 +237,15 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke
             if !has_argument_grads(node.dist)[i]
                 error("Distribution $(node.dist) does not have logpdf gradient for argument $i")
             end
-            push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $logpdf_grad[$(QuoteNode(i+1))], scale_factor)))
-            #push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))]))
+            input_node_grad = gradient_var(input_node)
+            increment = :($logpdf_grad[$(QuoteNode(i+1))])
+            if isa(input_node, TrainableParameterNode)
+                push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment, scale_factor)))
+            else
+                push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment)))
+            end
         end
     end
 
@@ -257,7 +254,6 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
             $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += retval_grad))
     end
 end
 
@@ -274,8 +270,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
             error("Distribution $dist does not have logpdf gradient for its output value")
         end
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1]))
+            $(gradient_var(node)), retval_grad)))
     end
 end
 
@@ -292,8 +287,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
 
     if node in fwd_marked
@@ -316,9 +310,10 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     for (i, input_node) in enumerate(node.inputs)
         if input_node in fwd_marked
             @assert input_node in back_marked # this ensured its gradient will have been initialized
+            input_node_grad = gradient_var(input_node)
+            increment = :($input_grads[$(QuoteNode(i))])
             push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor)))
-            #push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))]))
+                $input_node_grad, $increment)))
         end
     end
 
@@ -332,8 +327,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        #push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
 
     if node in fwd_marked
@@ -347,9 +341,15 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     for (i, (input_node, has_grad)) in enumerate(zip(node.inputs, has_argument_grads(node.generative_function)))
         if input_node in fwd_marked && has_grad
             @assert input_node in back_marked # this ensured its gradient will have been initialized
-            push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor)))
-            #push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))]))
+            input_node_grad = gradient_var(input_node)
+            increment = :($input_grads[$(QuoteNode(i))])
+            if isa(input_node, TrainableParameterNode)
+                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment, scale_factor)))
+            else
+                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment)))
+            end
         end
     end
 end
@@ -425,16 +425,14 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
         return quote choice_gradients(trace, StaticSelection(selection), retval_grad) end
     end
 
-    push!(stmts, :(scale_factor = NaN))
-
     ir = get_ir(gen_fn_type)
     selected_choices = get_selected_choices(schema, ir)
     selected_calls = get_selected_calls(schema, ir)
 
     # forward marking pass
     fwd_marked = Set{StaticIRNode}()
     for node in ir.nodes
-        fwd_pass!(selected_choices, selected_calls, fwd_marked, node)
+        fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropTraceMode())
     end
 
     # backward marking pass
@@ -489,13 +487,13 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T},
     selected_calls = Set{GenerativeFunctionCallNode}(
         node for node in ir.nodes if isa(node, GenerativeFunctionCallNode))
 
-    # forward marking pass
+    # forward marking pass (propagate forward from 'sources')
     fwd_marked = Set{StaticIRNode}()
     for node in ir.nodes
-        fwd_pass!(selected_choices, selected_calls, fwd_marked, node)
+        fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropParamsMode())
     end
 
-    # backward marking pass
+    # backward marking pass (propagate backwards from 'sinks')
     back_marked = Set{StaticIRNode}()
     push!(back_marked, ir.return_node)
     for node in reverse(ir.nodes)
@@ -508,12 +506,15 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T},
     arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes]
     push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace)))
 
-    # forward code-generation pass (initialize gradients to zero, create needed references)
+    # forward code-generation pass
+    # any node that is backward-marked creates a variable for its current value
+    # any node that is forward-marked and backward-marked initializes a gradient variable
     for node in ir.nodes
         fwd_codegen!(stmts, fwd_marked, back_marked, node)
     end
 
-    # backward code-generation pass (increment gradients)
+    # backward code-generation pass
+    # any node that is forward-marked and backward-marked increments its gradient variable
     for node in reverse(ir.nodes)
         back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode())
     end
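The comments added above describe a two-phase scheme: a forward marking pass that propagates from gradient sources (selected choices and calls, arguments, trainable parameters) and a backward marking pass that propagates from sinks (the return value and every logpdf). A standalone toy sketch of those two passes over a small dependency list, assuming each node simply lists its inputs (not Gen's StaticIR types):

# Toy node: a name plus the names of its inputs.
struct ToyNode
    name::Symbol
    inputs::Vector{Symbol}
end

# Nodes in topological order: a -> b -> c (c is the return value).
nodes = [ToyNode(:a, Symbol[]), ToyNode(:b, [:a]), ToyNode(:c, [:b])]
sources = Set([:a])          # e.g. a selected choice or a trainable parameter

# Forward marking: a node is marked if it is a source or depends on a marked node.
fwd_marked = Set{Symbol}()
for node in nodes
    if node.name in sources || any(in(fwd_marked), node.inputs)
        push!(fwd_marked, node.name)
    end
end

# Backward marking: start from the sink (return value) and walk back through inputs.
back_marked = Set{Symbol}([:c])
for node in reverse(nodes)
    if node.name in back_marked
        union!(back_marked, node.inputs)
    end
end

# Gradient work is only generated for nodes marked in both passes.
needs_gradient = intersect(fwd_marked, back_marked)   # Set([:a, :b, :c])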

src/static_ir/generate.jl

Lines changed: 6 additions & 3 deletions
@@ -6,7 +6,7 @@ end
 function process!(::StaticIRGenerateState, node, options) end
 
 function process!(state::StaticIRGenerateState, node::TrainableParameterNode, options)
-    push!(state.stmts, :($(node.name) = $(QuoteNode(get_param))(gen_fn, $(QuoteNode(node.name)))))
+    push!(state.stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name)))))
 end
 
 function process!(state::StaticIRGenerateState, node::ArgumentNode, options)
@@ -84,6 +84,7 @@ function codegen_generate(gen_fn_type::Type{T}, args,
     push!(stmts, :($total_noise_fieldname = 0.))
     push!(stmts, :($weight = 0.))
     push!(stmts, :($num_nonempty_fieldname = 0))
+    push!(stmts, :($parameter_store_fieldname = $(QuoteNode(get_julia_store))(parameter_context)))
 
     # unpack arguments
     arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes]
@@ -109,8 +110,10 @@ function codegen_generate(gen_fn_type::Type{T}, args,
     end
 
     push!(generated_functions, quote
-        @generated function $(GlobalRef(Gen, :generate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)),
-            args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)))
+        @generated function $(GlobalRef(Gen, :generate))(
+                gen_fn::$(QuoteNode(StaticIRGenerativeFunction)),
+                args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap));
+                parameter_context=default_parameter_context)
             $(QuoteNode(codegen_generate))(gen_fn, args, constraints)
         end
     end)
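For readers following the new parameter_context plumbing: generate now accepts a parameter_context keyword, and the generated body fetches the Julia parameter store from it via get_julia_store. A standalone sketch of that lookup pattern, with a hypothetical dictionary-backed context and a toy store type (the real JuliaParameterStore and default_parameter_context live elsewhere in this changeset):

# Toy stand-in for the Julia parameter store.
struct ToyJuliaStore
    values::Dict{Symbol,Any}
end

# Hypothetical context layout: a Dict keyed by store kind.
const default_parameter_context = Dict{Symbol,Any}()

# Fetch (or lazily create) the Julia store held by a parameter context.
get_julia_store(context::Dict{Symbol,Any}) =
    get!(context, :julia_store, ToyJuliaStore(Dict{Symbol,Any}()))::ToyJuliaStore

store = get_julia_store(default_parameter_context)   # same store on repeated calls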
