
Commit ad05f34

checkpoint on getting tests to pass again

1 parent 4b5d804

25 files changed: +569 -609 lines

src/Gen.jl

Lines changed: 0 additions & 3 deletions
@@ -64,9 +64,6 @@ include("dynamic/dynamic.jl")
 # static IR generative function
 include("static_ir/static_ir.jl")
 
-# optimization for built-in generative functions (dynamic and static IR)
-include("builtin_optimization.jl")
-
 # DSLs for defining dynamic embedded and static IR generative functions
 # 'Dynamic DSL' and 'Static DSL'
 include("dsl/dsl.jl")

src/builtin_optimization.jl

Lines changed: 0 additions & 235 deletions
@@ -1,235 +0,0 @@
-#############################
-
-# primitives for in-place gradient accumulation
-
-function in_place_add!(param::Array, increment::Array, scale_factor::Real)
-    # NOTE: it ignores the scale_factor, because it is not a parameter...
-    # scale factors only affect parameters
-    # TODO this is potentially very confusing!
-    @simd for i in 1:length(param)
-        param[i] += increment[i]
-    end
-    return param
-end
-
-function in_place_add!(param::Array, increment::Array)
-    @inbounds @simd for i in 1:length(param)
-        param[i] += increment[i]
-    end
-    return param
-end
-
-function in_place_add!(param::Real, increment::Real, scale_factor::Real)
-    return param + increment
-end
-
-function in_place_add!(param::Real, increment::Real)
-    return param + increment
-end
-
-mutable struct ThreadsafeAccumulator{T}
-    value::T
-    lock::ReentrantLock
-end
-
-ThreadsafeAccumulator(value) = ThreadsafeAccumulator(value, ReentrantLock())
-
-# TODO not threadsafe
-function get_current_value(accum::ThreadsafeAccumulator)
-    return accum.value
-end
-
-function in_place_add!(param::ThreadsafeAccumulator{Real}, increment::Real, scale_factor::Real)
-    lock(param.lock)
-    try
-        param.value = param.value + increment * scale_factor
-    finally
-        unlock(param.lock)
-    end
-    return param
-end
-
-function in_place_add!(param::ThreadsafeAccumulator{Real}, increment::Real)
-    lock(param.lock)
-    try
-        param.value = param.value + increment
-    finally
-        unlock(param.lock)
-    end
-    return param
-end
-
-function in_place_add!(param::ThreadsafeAccumulator{<:Array}, increment, scale_factor::Real)
-    lock(param.lock)
-    try
-        @simd for i in 1:length(param.value)
-            param.value[i] += increment[i] * scale_factor
-        end
-    finally
-        unlock(param.lock)
-    end
-    return param
-end
-
-function in_place_add!(param::ThreadsafeAccumulator{<:Array}, increment)
-    lock(param.lock)
-    try
-        @simd for i in 1:length(param.value)
-            param.value[i] += increment[i]
-        end
-    finally
-        unlock(param.lock)
-    end
-    return param
-end
-
-#############################
-
-
-
-"""
-    set_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value)
-
-Set the value of a trainable parameter of the generative function.
-
-NOTE: Does not update the gradient accumulator value.
-"""
-function set_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value)
-    return gf.params[name] = value
-end
-
-"""
-    value = get_param(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol)
-
-Get the current value of a trainable parameter of the generative function.
-"""
-function get_param(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol)
-    return gf.params[name]
-end
-
-"""
-    value = get_param_grad(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol)
-
-Get the current value of the gradient accumulator for a trainable parameter of the generative function.
-
-Not threadsafe.
-"""
-function get_param_grad(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol)
-    try
-        val = gf.params_grad[name] # the accumulator
-        return get_current_value(val)
-    catch KeyError
-        error("parameter $name not found")
-    end
-    return val
-end
-
-"""
-    zero_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol)
-
-Reset the gradient accumulator for a trainable parameter of the generative function to all zeros.
-
-Not threadsafe.
-"""
-function zero_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol)
-    gf.params_grad[name] = ThreadsafeAccumulator(zero(gf.params[name])) # TODO avoid allocation?
-    return gf.params_grad[name]
-end
-
-"""
-    set_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value)
-
-Set the gradient accumulator for a trainable parameter of the generative function.
-
-Not threadsafe.
-"""
-function set_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value)
-    gf.params_grad[name] = ThreadsafeAccumulator(grad_value)
-    return grad_value
-end
-
-# TODO document me; it is threadsafe..
-function increment_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, increment, scale_factor)
-    in_place_add!(gf.params_grad[name], increment, scale_factor)
-end
-
-# TODO document me; it is threadsafe..
-function increment_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, increment)
-    in_place_add!(gf.params_grad[name], increment)
-end
-
-
-
-"""
-    init_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value)
-
-Initialize the value of a named trainable parameter of a generative function.
-
-Also initializes the gradient accumulator for that parameter to `zero(value)`.
-
-Example:
-```julia
-init_param!(foo, :theta, 0.6)
-```
-"""
-function init_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value)
-    set_param!(gf, name, value)
-    zero_param_grad!(gf, name)
-end
-
-get_params(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}) = keys(gf.params)
-
-export set_param!, get_param, get_param_grad, zero_param_grad!, set_param_grad!, init_param!, increment_param_grad!
-
-#########################################
-# gradient descent with fixed step size #
-#########################################
-
-mutable struct FixedStepGradientDescentBuiltinDSLState
-    step_size::Float64
-    gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}
-    param_list::Vector
-end
-
-function init_update_state(conf::FixedStepGradientDescent,
-        gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, param_list::Vector)
-    FixedStepGradientDescentBuiltinDSLState(conf.step_size, gen_fn, param_list)
-end
-
-function apply_update!(state::FixedStepGradientDescentBuiltinDSLState)
-    for param_name in state.param_list
-        value = get_param(state.gen_fn, param_name)
-        grad = get_param_grad(state.gen_fn, param_name)
-        set_param!(state.gen_fn, param_name, value + grad * state.step_size)
-        zero_param_grad!(state.gen_fn, param_name)
-    end
-end
-
-####################
-# gradient descent #
-####################
-
-mutable struct GradientDescentBuiltinDSLState
-    step_size_init::Float64
-    step_size_beta::Float64
-    gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}
-    param_list::Vector
-    t::Int
-end
-
-function init_update_state(conf::GradientDescent,
-        gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, param_list::Vector)
-    GradientDescentBuiltinDSLState(conf.step_size_init, conf.step_size_beta,
-        gen_fn, param_list, 1)
-end
-
-function apply_update!(state::GradientDescentBuiltinDSLState)
-    step_size = state.step_size_init * (state.step_size_beta + 1) / (state.step_size_beta + state.t)
-    for param_name in state.param_list
-        value = get_param(state.gen_fn, param_name)
-        grad = get_param_grad(state.gen_fn, param_name)
-        set_param!(state.gen_fn, param_name, value + grad * step_size)
-        zero_param_grad!(state.gen_fn, param_name)
-    end
-    state.t += 1
-end
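For context, a minimal sketch of how the removed file's API fit together before this commit: each generative function held its own parameter dictionaries (`gf.params`, `gf.params_grad`), gradients were accumulated into a `ThreadsafeAccumulator`, and the update states stepped parameters along the accumulated gradient (ascent on log probability). The model `foo` and its parameter `:theta` are made up for illustration; `init_update_state`/`apply_update!` are the internal hooks defined above, which released Gen wraps in the public `ParamUpdate`/`apply!` interface.

```julia
using Gen

# Made-up model for illustration; :theta is a trainable parameter.
@gen function foo(x)
    @param theta::Float64
    y ~ normal(theta * x, 1.0)
    return y
end

init_param!(foo, :theta, 0.6)    # set value and zero the gradient accumulator

# Accumulate d(log p)/d(theta) from a trace with an observed :y.
(trace, _) = generate(foo, (1.0,), choicemap((:y, 2.5)))
accumulate_param_gradients!(trace)

# One fixed-step ascent move; re-zeroes the accumulator afterwards.
state = init_update_state(FixedStepGradientDescent(0.01), foo, [:theta])
apply_update!(state)
```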

src/dynamic/backprop.jl

Lines changed: 15 additions & 13 deletions
@@ -42,10 +42,10 @@ mutable struct GFBackpropParamsState
     visitor::AddressVisitor
     scale_factor::Float64
     active_gen_fn::GenerativeFunction
-    tracked_params::Dict{ParameterID,Any}
+    tracked_params::Dict{Tuple{GenerativeFunction,Symbol},Any}
 
     function GFBackpropParamsState(trace::DynamicDSLTrace, tape, scale_factor)
-        tracked_params = Dict{ParameterID,Any}()
+        tracked_params = Dict{Tuple{GenerativeFunction,Symbol},Any}()
         store = get_parameter_store(trace)
         gen_fn = get_gen_fn(trace)
         for (name, value) in get_local_parameters(store, gen_fn)
@@ -178,9 +178,10 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_
     reverse_pass!(tape)
 
     # increment the gradient accumulators for trainable parameters in scope
+    store = get_parameter_store(trace)
     for ((active_gen_fn, name), tracked) in state.tracked_params
-        parameter_id = (active_gen_fn, parameter_id)
-        increment_gradient!(store, parameter_id, deriv(tracked), state.scale_factor)
+        parameter_id = (active_gen_fn, name)
+        increment_gradient!(parameter_id, deriv(tracked), state.scale_factor, store)
     end
 
     # return gradients with respect to arguments with gradients, or nothing
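The loop above keys each gradient accumulator by a `(gen_fn, name)` parameter ID, scaled by `scale_factor`, with the store as the trailing argument. A self-contained toy illustration of that keyed-accumulator pattern (`ToyStore` is hypothetical and is not Gen's `JuliaParameterStore` implementation):

```julia
# Toy store mapping (gen_fn, name) parameter IDs to scalar gradient totals.
struct ToyStore
    grads::Dict{Tuple{Any,Symbol},Float64}
end
ToyStore() = ToyStore(Dict{Tuple{Any,Symbol},Float64}())

function increment_gradient!(parameter_id::Tuple, increment::Real,
                             scale_factor::Real, store::ToyStore)
    # Accumulate the scaled increment into the entry for this parameter ID.
    store.grads[parameter_id] = get(store.grads, parameter_id, 0.0) +
                                increment * scale_factor
    return store
end
```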
@@ -217,6 +218,16 @@ function GFBackpropTraceState(trace, selection, tape)
         get_gen_fn(trace))
 end
 
+get_parameter_store(state::GFBackpropTraceState) = get_parameter_store(state.trace)
+
+get_parameter_id(state::GFBackpropTraceState, name::Symbol) = (state.active_gen_fn, name)
+
+get_active_gen_fn(state::GFBackpropTraceState) = state.active_gen_fn
+
+function set_active_gen_fn!(state::GFBackpropTraceState, gen_fn::GenerativeFunction)
+    state.active_gen_fn = gen_fn
+end
+
 function fill_submaps!(
     map::DynamicChoiceMap,
     tracked_trie::Trie{Any,Union{TrackedReal,TrackedArray}},
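These one-line accessors define a small interface over the interpreter state. A sketch of the assumed contract: any state type exposing these methods can host the generic `splice` added in src/dynamic/dynamic.jl below, which is why the state-specific `splice` method in the next hunk can be deleted. `DemoState` is hypothetical, for illustration only.

```julia
using Gen: GenerativeFunction

# Hypothetical state type satisfying the assumed accessor contract.
mutable struct DemoState
    active_gen_fn::GenerativeFunction   # generative function currently executing
end

get_active_gen_fn(state::DemoState) = state.active_gen_fn

function set_active_gen_fn!(state::DemoState, gen_fn::GenerativeFunction)
    state.active_gen_fn = gen_fn
end

# Parameter IDs are (gen_fn, name) tuples, matching the tracked_params key
# type introduced in the first hunk of this file.
get_parameter_id(state::DemoState, name::Symbol) = (state.active_gen_fn, name)
```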
@@ -304,15 +315,6 @@ function traceat(state::GFBackpropTraceState, gen_fn::GenerativeFunction{T,U},
     retval_maybe_tracked
 end
 
-function splice(state::GFBackpropTraceState, gen_fn::DynamicDSLFunction,
-        args_maybe_tracked::Tuple)
-    prev_gen_fn = state.active_gen_fn
-    state.active_gen_fn = gen_fn
-    retval = exec(gen_fn, state, args)
-    state.active_gen_fn = prev_gen_fn
-    return retval
-end
-
 @noinline function ReverseDiff.special_reverse_exec!(
     instruction::ReverseDiff.SpecialInstruction{BackpropTraceRecord})
     record = instruction.func

src/dynamic/dynamic.jl

Lines changed: 19 additions & 3 deletions
@@ -29,14 +29,22 @@ function DynamicDSLFunction(arg_types::Vector{Type},
         has_argument_grads, accepts_output_grad)
 end
 
-function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction}
+function Base.show(io::IO, gen_fn::DynamicDSLFunction)
+    print(io, "Gen DML generative function: $(gen_fn.julia_function)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
+    print(io, "Gen DML generative function: $(gen_fn.julia_function)")
+end
+
+function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction}
     # pad args with default values, if available
     if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults)
        defaults = gen_fn.arg_defaults[length(args)+1:end]
        defaults = map(x -> something(x), defaults)
        args = Tuple(vcat(collect(args), defaults))
     end
-    return DynamicDSLTrace{T}(gen_fn, args)
+    return DynamicDSLTrace{T}(gen_fn, args, parameter_store)
 end
 
 accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad
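A hypothetical REPL interaction showing the intended display (the `Base.show` methods must `print` to `io` rather than return a string for this to work; `model` is a made-up function and the printed closure name is illustrative):

```julia
julia> using Gen

julia> @gen function model()
           x ~ normal(0.0, 1.0)
       end;

julia> model
Gen DML generative function: ##model#365
```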
@@ -58,6 +66,14 @@ function exec(gen_fn::DynamicDSLFunction, state, args::Tuple)
     gen_fn.julia_function(state, args...)
 end
 
+function splice(state, gen_fn::DynamicDSLFunction, args::Tuple)
+    prev_gen_fn = get_active_gen_fn(state)
+    set_active_gen_fn!(state, gen_fn)
+    retval = exec(gen_fn, state, args)
+    set_active_gen_fn!(state, prev_gen_fn)
+    return retval
+end
+
 # whether there is a gradient of score with respect to each argument
 # it returns 'nothing' for those arguments that don't have a derivative
 has_argument_grads(gen::DynamicDSLFunction) = gen.has_argument_grads
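The generic `splice` saves and restores the active generative function around an inlined call, so that `@param` reads inside the callee resolve against the callee's own parameters rather than the caller's. A sketch of the modeling code it serves, assuming made-up models `inner` and `outer` and Gen's `{*} ~` syntax for calling a generative function without a fresh address namespace:

```julia
using Gen

@gen function inner()
    x ~ normal(0.0, 1.0)
    return x
end

@gen function outer()
    # inner's choice :x lands at outer's top level; while inner's body runs,
    # inner must temporarily become the active generative function.
    r = {*} ~ inner()
    return r
end
```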
@@ -104,7 +120,7 @@ end
 function read_param(state, name::Symbol)
     parameter_id = get_parameter_id(state, name)
     store = get_parameter_store(state)
-    return get_parameter_value(store, parameter_id)
+    return get_parameter_value(parameter_id, store)
 end
 
 ##################
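For reference, `read_param` is the runtime entry point behind `@param` reads in the DML: it resolves the `(gen_fn, name)` parameter ID against the parameter store. A sketch of the assumed desugaring, with a made-up model:

```julia
using Gen

@gen function model(x)
    # The @param read compiles to roughly: theta = read_param(state, :theta)
    @param theta::Float64
    y ~ normal(theta * x, 1.0)
    return y
end
```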
