Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3e4f695
Initial work on a Switch combinator.
femtomc Nov 17, 2020
bd4f830
Initial implementation of propose and generate.
femtomc Nov 17, 2020
374a7b0
Added implementaton of simulate.
femtomc Nov 17, 2020
5872593
Corrected some bugs with Bernoulli vs bernoulli.
femtomc Nov 17, 2020
9c0a9f2
Added assess implementation.
femtomc Nov 17, 2020
95baf07
Split into two combinators: Switch and WithProbability implementations.
femtomc Nov 18, 2020
29b7797
Working on Switch update and regenerate.
femtomc Nov 18, 2020
3e6e307
Added Switch update and regenerate.
femtomc Nov 18, 2020
7929b86
Added Switch update and regenerate - working out kinks in update.
femtomc Nov 18, 2020
73618a1
update and regenerate appear to be computing the correct ratios. To c…
femtomc Nov 18, 2020
252413f
Fixed generate index type bug.
femtomc Nov 18, 2020
ac3528e
Branch dispatch done using diff types.
femtomc Nov 18, 2020
eaf3327
Branch dispatch done using diff types.
femtomc Nov 18, 2020
6d58aac
Branch dispatch done using diff types.
femtomc Nov 18, 2020
e413e9c
Added custom methods in update for Switch which allow the merging of …
femtomc Nov 18, 2020
435493f
Added custom methods in update for Switch which allow the merging of …
femtomc Nov 18, 2020
32fec4f
Idiomatic check for EmptyChoiceMap.
femtomc Nov 18, 2020
bb767e7
Working on backprop - seems simple? Could it really be?
femtomc Nov 18, 2020
a35e2e7
Extracting WithProb combinator into another PR.
femtomc Nov 18, 2020
562667e
Testing backprop.
femtomc Nov 19, 2020
b74a071
Fixed backprop - was thinking in Zygote lang. Gradients appear to be …
femtomc Nov 19, 2020
915811d
Merge branch 'master' of https://github.com/probcomp/Gen.jl into 2020…
femtomc Nov 19, 2020
849d61e
Added docstring and docs example.
femtomc Nov 19, 2020
adf73a5
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc Nov 19, 2020
dfe0125
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc Nov 20, 2020
3717d65
Tests for everything but gradients - working on gradients now.
femtomc Nov 20, 2020
cb62fb5
Last tests I need to write: accumulate_param_gradients!
femtomc Nov 20, 2020
97473d0
Added accumulate_param_gradients! tests.
femtomc Nov 20, 2020
176b9e9
Reverted particle filter fix - will be handled in another issue.
femtomc Nov 20, 2020
0465965
Renamed mix field of Switch generative function to branches to more a…
femtomc Nov 22, 2020
43c7274
Addressed review comments. Added docstrings where necessary. Correcte…
femtomc Dec 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/modeling_library/cond.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ------------ WithProbability trace ------------ #

struct WithProbabilityTrace{T1, T2, Tr} <: Trace
gen_fn::GenerativeFunction{Union{T1, T2}, Tr}
p::Float64
cond::Bool
branch::Tr
retval::Union{T1, T2}
args::Tuple
score::Float64
noise::Float64
end

@inline function get_choices(tr::WithProbabilityTrace)
choices = choicemap()
set_submap!(choices, :branch, get_choices(tr.branch))
set_value!(choices, :cond, tr.cond)
choices
end
@inline get_retval(tr::WithProbabilityTrace) = tr.retval
@inline get_args(tr::WithProbabilityTrace) = tr.args
@inline get_score(tr::WithProbabilityTrace) = tr.score
@inline get_gen_fn(tr::WithProbabilityTrace) = tr.gen_fn

@inline function Base.getindex(tr::WithProbabilityTrace, addr::Pair)
(first, rest) = addr
subtr = getfield(trace, first)
subtrace[rest]
end
@inline Base.getindex(tr::WithProbabilityTrace, addr::Symbol) = getfield(trace, addr)

function project(tr::WithProbabilityTrace, selection::Selection)
sum(map([:cond, :branch]) do k
subselection = selection[k]
project(getindex(tr, k), subselection)
end)
end
project(tr::WithProbabilityTrace, ::EmptySelection) = tr.noise

# ------------ Switch trace ------------ #

struct SwitchTrace{T} <: Trace
gen_fn::GenerativeFunction{T}
index::Int
branch::Trace
retval::T
args::Tuple
score::Float64
noise::Float64
end

@inline get_choices(tr::SwitchTrace) = get_choices(tr.branch)
@inline get_retval(tr::SwitchTrace) = tr.retval
@inline get_args(tr::SwitchTrace) = tr.args
@inline get_score(tr::SwitchTrace) = tr.score
@inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn

@inline function Base.getindex(tr::SwitchTrace, addr::Pair)
(first, rest) = addr
subtr = getfield(trace, first)
subtrace[rest]
end
@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr)

@inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection)
@inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise
5 changes: 5 additions & 0 deletions src/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ include("dist_dsl/dist_dsl.jl")
# code shared by vector-shaped combinators
include("vector.jl")

# traces for with prob/switch combinator
include("cond.jl")

# built-in generative function combinators
include("choice_at/choice_at.jl")
include("call_at/call_at.jl")
include("map/map.jl")
include("unfold/unfold.jl")
include("recurse/recurse.jl")
include("switch/switch.jl")
include("with_prob/with_prob.jl")

#############################################################
# abstractions for constructing custom generative functions #
Expand Down
25 changes: 25 additions & 0 deletions src/modeling_library/switch/assess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
mutable struct SwitchAssessState{T}
weight::Float64
retval::T
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
choices::ChoiceMap,
state::SwitchAssessState{T}) where {C, N, K, T}
(weight, retval) = assess(getindex(gen_fn.mix, index), kernel_args, choices)
state.weight += weight
state.retval = retval
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)

function assess(gen_fn::Switch{C, N, K, T},
args::Tuple,
choices::ChoiceMap) where {C, N, K, T}
index = args[1]
state = SwitchAssessState{T}(0.0)
process!(gen_fn, index, args[2 : end], choices, state)
return state.weight, state.retval
end
34 changes: 34 additions & 0 deletions src/modeling_library/switch/generate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
mutable struct SwitchGenerateState{T}
score::Float64
noise::Float64
weight::Float64
index::Int
subtrace::Trace
retval::T
SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
choices::ChoiceMap,
state::SwitchGenerateState{T}) where {C, N, K, T}

(subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices)
state.index = index
state.subtrace = subtrace
state.weight += weight
state.retval = get_retval(subtrace)
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)

function generate(gen_fn::Switch{C, N, K, T},
args::Tuple,
choices::ChoiceMap) where {C, N, K, T}

index = args[1]
state = SwitchGenerateState{T}(0.0, 0.0, 0.0)
process!(gen_fn, index, args[2 : end], choices, state)
return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight
end
29 changes: 29 additions & 0 deletions src/modeling_library/switch/propose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
mutable struct SwitchProposeState{T}
choices::DynamicChoiceMap
weight::Float64
retval::T
SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
state::SwitchProposeState{T}) where {C, N, K, T}

(submap, weight, retval) = propose(getindex(gen_fn.mix, index), args)
state.choices = submap
state.weight += weight
state.retval = retval
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)

function propose(gen_fn::Switch{C, N, K, T},
args::Tuple) where {C, N, K, T}

index = args[1]
choices = choicemap()
state = SwitchProposeState{T}(choices, 0.0)
process!(gen_fn, index, args[2:end], state)
return state.choices, state.weight, state.retval
end
59 changes: 59 additions & 0 deletions src/modeling_library/switch/regenerate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
mutable struct SwitchRegenerateState{T}
weight::Float64
score::Float64
noise::Float64
prev_trace::Trace
trace::Trace
index::Int
retdiff::Diff
SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
index_argdiff::UnknownChange,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to use Diff instead of UnknownChange here. There might be other Diff types that could be passed that are intermediate between UnknownChange and NoChange (e.g. there is an IntDiff already, which tracks the arithmetic difference between two integers). This method should apply to anything that's not a NoChange.

args::Tuple,
kernel_argdiffs::Tuple,
selection::Selection,
state::SwitchRegenerateState{T}) where {C, N, K, T}
merged = get_choices(state.prev_trace)
branch_fn = getfield(gen_fn.mix, index)
new_trace, weight = generate(branch_fn, args, merged)
retdiff = UnknownChange()
weight -= get_score(state.prev_trace)
state.index = index
state.weight = weight
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
state.score = get_score(new_trace)
state.trace = new_trace
state.retdiff = retdiff
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
index_argdiff::NoChange,
args::Tuple,
kernel_argdiffs::Tuple,
selection::Selection,
state::SwitchRegenerateState{T}) where {C, N, K, T}
new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection)
state.index = index
state.weight = weight
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
state.score = get_score(new_trace)
state.trace = new_trace
state.retdiff = retdiff
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state)

function regenerate(trace::SwitchTrace{T},
args::Tuple,
argdiffs::Tuple,
selection::Selection) where T
gen_fn = trace.gen_fn
index, index_argdiff = args[1], argdiffs[1]
state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace)
process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state)
return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff
end
32 changes: 32 additions & 0 deletions src/modeling_library/switch/simulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
mutable struct SwitchSimulateState{T}
score::Float64
noise::Float64
index::Int
subtrace::Trace
retval::T
SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
state::SwitchSimulateState{T}) where {C, N, K, T}
local retval::T
subtrace = simulate(getindex(gen_fn.mix, index), args)
state.index = index
state.noise += project(subtrace, EmptySelection())
state.subtrace = subtrace
state.score += get_score(subtrace)
state.retval = get_retval(subtrace)
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)

function simulate(gen_fn::Switch{C, N, K, T},
args::Tuple) where {C, N, K, T}

index = args[1]
state = SwitchSimulateState{T}(0.0, 0.0)
process!(gen_fn, index, args[2 : end], state)
SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise)
end
36 changes: 36 additions & 0 deletions src/modeling_library/switch/switch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace}
mix::NTuple{N, GenerativeFunction{T}}
cases::Dict{C, Int}
function Switch(gen_fns::GenerativeFunction...)
@assert !isempty(gen_fns)
rettype = get_return_type(getindex(gen_fns, 1))
new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}())
end
function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C
@assert !isempty(gen_fns)
rettype = get_return_type(getindex(gen_fns, 1))
new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d)
end
end

export Switch

has_argument_grads(switch_fn::Switch) = all(has_argument_grads, switch.mix)
accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch.mix)

function (gen_fn::Switch)(index::Int, args...)
(_, _, retval) = propose(gen_fn, (index, args...))
retval
end

function (gen_fn::Switch{C})(index::C, args...) where C
(_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...))
retval
end

include("assess.jl")
include("propose.jl")
include("simulate.jl")
include("generate.jl")
include("update.jl")
include("regenerate.jl")
70 changes: 70 additions & 0 deletions src/modeling_library/switch/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
mutable struct SwitchUpdateState{T}
weight::Float64
score::Float64
noise::Float64
prev_trace::Trace
trace::Trace
index::Int
discard::ChoiceMap
updated_retdiff::Diff
SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
end

function process!(gen_fn::Switch{C, N, K, T},
Copy link
Contributor Author

@femtomc femtomc Nov 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcoct This seems correct for the "namespace merging" semantics. In update, if I switch branches, I generate with a merge of the previous traces choices and the choice map. I'm worried this will throw an error (however) if not all constraints are visited (e.g. if the previous trace has non-empty set difference with the new namespace).

I'm also dispatching on process! using the diff types. I'm not sure if this current implementation is correct - mostly because I'm not supported Diffed dispatch yet (and I suspect I have to).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also dispatching on process! using the diff types. I'm not sure if this current implementation is correct - mostly because I'm not supported Diffed dispatch yet (and I suspect I have to).

@femtomc The Diffed types are only used for boxing Julia values for diff propagation of Julia code via operator overloading. GFI implementers only need to worry about Diff values. So what you have is right.

index::Int,
index_argdiff::UnknownChange, # TODO: Diffed wrapper?
args::Tuple,
kernel_argdiffs::Tuple,
choices::ChoiceMap,
state::SwitchUpdateState{T}) where {C, N, K, T, DV}

# Generate new trace.
merged = merge(get_choices(state.prev_trace), choices)
branch_fn = getfield(gen_fn.mix, index)
new_trace, weight = generate(branch_fn, args, merged)
retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch))
weight -= get_score(state.prev_trace)

# Set state.
state.index = index
state.weight = weight
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
state.score = get_score(new_trace)
state.trace = new_trace
state.updated_retdiff = retdiff
state.discard = discard
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
index_argdiff::NoChange, # TODO: Diffed wrapper?
args::Tuple,
kernel_argdiffs::Tuple,
choices::ChoiceMap,
state::SwitchUpdateState{T}) where {C, N, K, T}

# Update trace.
new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices)

# Set state.
state.index = index
state.weight = weight
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
state.score = get_score(new_trace)
state.trace = new_trace
state.updated_retdiff = retdiff
state.discard = discard
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state)

function update(trace::SwitchTrace{T},
args::Tuple,
argdiffs::Tuple,
choices::ChoiceMap) where T
gen_fn = trace.gen_fn
index, index_argdiff = args[1], argdiffs[1]
state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace)
process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], choices, state)
return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.updated_retdiff, state.discard
end
Loading