-
Notifications
You must be signed in to change notification settings - Fork 162
(Ready for review): Switch combinator #334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
3e4f695
bd4f830
374a7b0
5872593
9c0a9f2
95baf07
29b7797
3e6e307
7929b86
73618a1
252413f
ac3528e
eaf3327
6d58aac
e413e9c
435493f
32fec4f
bb767e7
a35e2e7
562667e
b74a071
915811d
849d61e
adf73a5
dfe0125
3717d65
cb62fb5
97473d0
176b9e9
0465965
43c7274
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # ------------ WithProbability trace ------------ # | ||
|
|
||
| struct WithProbabilityTrace{T1, T2, Tr} <: Trace | ||
| gen_fn::GenerativeFunction{Union{T1, T2}, Tr} | ||
| p::Float64 | ||
| cond::Bool | ||
| branch::Tr | ||
| retval::Union{T1, T2} | ||
| args::Tuple | ||
| score::Float64 | ||
| noise::Float64 | ||
| end | ||
|
|
||
| @inline function get_choices(tr::WithProbabilityTrace) | ||
| choices = choicemap() | ||
| set_submap!(choices, :branch, get_choices(tr.branch)) | ||
| set_value!(choices, :cond, tr.cond) | ||
| choices | ||
| end | ||
| @inline get_retval(tr::WithProbabilityTrace) = tr.retval | ||
| @inline get_args(tr::WithProbabilityTrace) = tr.args | ||
| @inline get_score(tr::WithProbabilityTrace) = tr.score | ||
| @inline get_gen_fn(tr::WithProbabilityTrace) = tr.gen_fn | ||
|
|
||
| @inline function Base.getindex(tr::WithProbabilityTrace, addr::Pair) | ||
| (first, rest) = addr | ||
| subtr = getfield(trace, first) | ||
| subtrace[rest] | ||
| end | ||
| @inline Base.getindex(tr::WithProbabilityTrace, addr::Symbol) = getfield(trace, addr) | ||
|
|
||
| function project(tr::WithProbabilityTrace, selection::Selection) | ||
| sum(map([:cond, :branch]) do k | ||
| subselection = selection[k] | ||
| project(getindex(tr, k), subselection) | ||
| end) | ||
| end | ||
| project(tr::WithProbabilityTrace, ::EmptySelection) = tr.noise | ||
|
|
||
| # ------------ Switch trace ------------ # | ||
|
|
||
| struct SwitchTrace{T} <: Trace | ||
| gen_fn::GenerativeFunction{T} | ||
| index::Int | ||
| branch::Trace | ||
| retval::T | ||
| args::Tuple | ||
| score::Float64 | ||
| noise::Float64 | ||
| end | ||
|
|
||
| @inline get_choices(tr::SwitchTrace) = get_choices(tr.branch) | ||
| @inline get_retval(tr::SwitchTrace) = tr.retval | ||
| @inline get_args(tr::SwitchTrace) = tr.args | ||
| @inline get_score(tr::SwitchTrace) = tr.score | ||
| @inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn | ||
|
|
||
| @inline function Base.getindex(tr::SwitchTrace, addr::Pair) | ||
| (first, rest) = addr | ||
| subtr = getfield(trace, first) | ||
| subtrace[rest] | ||
| end | ||
| @inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) | ||
|
|
||
| @inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection) | ||
| @inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| mutable struct SwitchAssessState{T} | ||
| weight::Float64 | ||
| retval::T | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| choices::ChoiceMap, | ||
| state::SwitchAssessState{T}) where {C, N, K, T} | ||
| (weight, retval) = assess(getindex(gen_fn.mix, index), kernel_args, choices) | ||
| state.weight += weight | ||
| state.retval = retval | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) | ||
|
|
||
| function assess(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple, | ||
| choices::ChoiceMap) where {C, N, K, T} | ||
| index = args[1] | ||
| state = SwitchAssessState{T}(0.0) | ||
| process!(gen_fn, index, args[2 : end], choices, state) | ||
| return state.weight, state.retval | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| mutable struct SwitchGenerateState{T} | ||
| score::Float64 | ||
| noise::Float64 | ||
| weight::Float64 | ||
| index::Int | ||
| subtrace::Trace | ||
| retval::T | ||
| SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| choices::ChoiceMap, | ||
| state::SwitchGenerateState{T}) where {C, N, K, T} | ||
|
|
||
| (subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices) | ||
| state.index = index | ||
| state.subtrace = subtrace | ||
| state.weight += weight | ||
| state.retval = get_retval(subtrace) | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) | ||
|
|
||
| function generate(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple, | ||
| choices::ChoiceMap) where {C, N, K, T} | ||
|
|
||
| index = args[1] | ||
| state = SwitchGenerateState{T}(0.0, 0.0, 0.0) | ||
| process!(gen_fn, index, args[2 : end], choices, state) | ||
| return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| mutable struct SwitchProposeState{T} | ||
| choices::DynamicChoiceMap | ||
| weight::Float64 | ||
| retval::T | ||
| SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| state::SwitchProposeState{T}) where {C, N, K, T} | ||
|
|
||
| (submap, weight, retval) = propose(getindex(gen_fn.mix, index), args) | ||
| state.choices = submap | ||
| state.weight += weight | ||
| state.retval = retval | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) | ||
|
|
||
| function propose(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple) where {C, N, K, T} | ||
|
|
||
| index = args[1] | ||
| choices = choicemap() | ||
| state = SwitchProposeState{T}(choices, 0.0) | ||
| process!(gen_fn, index, args[2:end], state) | ||
| return state.choices, state.weight, state.retval | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| mutable struct SwitchRegenerateState{T} | ||
| weight::Float64 | ||
| score::Float64 | ||
| noise::Float64 | ||
| prev_trace::Trace | ||
| trace::Trace | ||
| index::Int | ||
| retdiff::Diff | ||
| SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| index_argdiff::UnknownChange, | ||
| args::Tuple, | ||
| kernel_argdiffs::Tuple, | ||
| selection::Selection, | ||
| state::SwitchRegenerateState{T}) where {C, N, K, T} | ||
| merged = get_choices(state.prev_trace) | ||
| branch_fn = getfield(gen_fn.mix, index) | ||
| new_trace, weight = generate(branch_fn, args, merged) | ||
| retdiff = UnknownChange() | ||
| weight -= get_score(state.prev_trace) | ||
| state.index = index | ||
| state.weight = weight | ||
| state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) | ||
| state.score = get_score(new_trace) | ||
| state.trace = new_trace | ||
| state.retdiff = retdiff | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| index_argdiff::NoChange, | ||
| args::Tuple, | ||
| kernel_argdiffs::Tuple, | ||
| selection::Selection, | ||
| state::SwitchRegenerateState{T}) where {C, N, K, T} | ||
| new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) | ||
| state.index = index | ||
| state.weight = weight | ||
| state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) | ||
| state.score = get_score(new_trace) | ||
| state.trace = new_trace | ||
| state.retdiff = retdiff | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) | ||
|
|
||
| function regenerate(trace::SwitchTrace{T}, | ||
| args::Tuple, | ||
| argdiffs::Tuple, | ||
| selection::Selection) where T | ||
| gen_fn = trace.gen_fn | ||
| index, index_argdiff = args[1], argdiffs[1] | ||
| state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) | ||
| process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state) | ||
| return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| mutable struct SwitchSimulateState{T} | ||
| score::Float64 | ||
| noise::Float64 | ||
| index::Int | ||
| subtrace::Trace | ||
| retval::T | ||
| SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| state::SwitchSimulateState{T}) where {C, N, K, T} | ||
| local retval::T | ||
| subtrace = simulate(getindex(gen_fn.mix, index), args) | ||
| state.index = index | ||
| state.noise += project(subtrace, EmptySelection()) | ||
| state.subtrace = subtrace | ||
| state.score += get_score(subtrace) | ||
| state.retval = get_retval(subtrace) | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) | ||
|
|
||
| function simulate(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple) where {C, N, K, T} | ||
|
|
||
| index = args[1] | ||
| state = SwitchSimulateState{T}(0.0, 0.0) | ||
| process!(gen_fn, index, args[2 : end], state) | ||
| SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} | ||
| mix::NTuple{N, GenerativeFunction{T}} | ||
| cases::Dict{C, Int} | ||
| function Switch(gen_fns::GenerativeFunction...) | ||
| @assert !isempty(gen_fns) | ||
| rettype = get_return_type(getindex(gen_fns, 1)) | ||
| new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}()) | ||
| end | ||
| function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C | ||
| @assert !isempty(gen_fns) | ||
| rettype = get_return_type(getindex(gen_fns, 1)) | ||
| new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d) | ||
| end | ||
| end | ||
|
|
||
| export Switch | ||
|
|
||
| has_argument_grads(switch_fn::Switch) = all(has_argument_grads, switch.mix) | ||
| accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch.mix) | ||
|
|
||
| function (gen_fn::Switch)(index::Int, args...) | ||
| (_, _, retval) = propose(gen_fn, (index, args...)) | ||
| retval | ||
| end | ||
|
|
||
| function (gen_fn::Switch{C})(index::C, args...) where C | ||
| (_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...)) | ||
| retval | ||
| end | ||
|
|
||
| include("assess.jl") | ||
| include("propose.jl") | ||
| include("simulate.jl") | ||
| include("generate.jl") | ||
| include("update.jl") | ||
| include("regenerate.jl") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| mutable struct SwitchUpdateState{T} | ||
| weight::Float64 | ||
| score::Float64 | ||
| noise::Float64 | ||
| prev_trace::Trace | ||
| trace::Trace | ||
| index::Int | ||
| discard::ChoiceMap | ||
| updated_retdiff::Diff | ||
| SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @marcoct This seems correct for the "namespace merging" semantics. In I'm also dispatching on
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@femtomc The |
||
| index::Int, | ||
| index_argdiff::UnknownChange, # TODO: Diffed wrapper? | ||
| args::Tuple, | ||
| kernel_argdiffs::Tuple, | ||
| choices::ChoiceMap, | ||
| state::SwitchUpdateState{T}) where {C, N, K, T, DV} | ||
|
|
||
| # Generate new trace. | ||
| merged = merge(get_choices(state.prev_trace), choices) | ||
| branch_fn = getfield(gen_fn.mix, index) | ||
| new_trace, weight = generate(branch_fn, args, merged) | ||
| retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch)) | ||
| weight -= get_score(state.prev_trace) | ||
|
|
||
| # Set state. | ||
| state.index = index | ||
| state.weight = weight | ||
| state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) | ||
| state.score = get_score(new_trace) | ||
| state.trace = new_trace | ||
| state.updated_retdiff = retdiff | ||
| state.discard = discard | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| index_argdiff::NoChange, # TODO: Diffed wrapper? | ||
| args::Tuple, | ||
| kernel_argdiffs::Tuple, | ||
| choices::ChoiceMap, | ||
| state::SwitchUpdateState{T}) where {C, N, K, T} | ||
|
|
||
| # Update trace. | ||
| new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) | ||
|
|
||
| # Set state. | ||
| state.index = index | ||
| state.weight = weight | ||
| state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) | ||
| state.score = get_score(new_trace) | ||
| state.trace = new_trace | ||
| state.updated_retdiff = retdiff | ||
| state.discard = discard | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state) | ||
|
|
||
| function update(trace::SwitchTrace{T}, | ||
| args::Tuple, | ||
| argdiffs::Tuple, | ||
| choices::ChoiceMap) where T | ||
| gen_fn = trace.gen_fn | ||
| index, index_argdiff = args[1], argdiffs[1] | ||
| state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace) | ||
| process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], choices, state) | ||
| return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.updated_retdiff, state.discard | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to use
Diffinstead ofUnknownChangehere. There might be otherDifftypes that could be passed that are intermediate betweenUnknownChangeandNoChange(e.g. there is anIntDiffalready, which tracks the arithmetic difference between two integers). This method should apply to anything that's not aNoChange.