Skip to content

Commit 38e6571

Browse files
authored
Merge pull request #334 from femtomc/20201116_mrb_switch_combinator
(Ready for review): Switch combinator
2 parents 52b94a7 + 43c7274 commit 38e6571

File tree

13 files changed

+782
-0
lines changed

13 files changed

+782
-0
lines changed

docs/src/ref/combinators.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,41 @@ TODO: document me
119119
<img src="../../images/recurse_combinator.png" alt="schematic of recurse combinatokr" width="70%"/>
120120
</div>
121121
```
122+
## Switch combinator
122123

124+
```@docs
125+
Switch
126+
```
127+
128+
In the schematic below, the kernel is denoted `S` and accepts an integer index `k`.
129+
130+
Consider the following constructions:
131+
132+
```julia
133+
@gen function bang((grad)(x::Float64), (grad)(y::Float64))
134+
std::Float64 = 3.0
135+
z = @trace(normal(x + y, std), :z)
136+
return z
137+
end
138+
139+
@gen function fuzz((grad)(x::Float64), (grad)(y::Float64))
140+
std::Float64 = 3.0
141+
z = @trace(normal(x + 2 * y, std), :z)
142+
return z
143+
end
144+
145+
sc = Switch(bang, fuzz)
146+
```
147+
148+
This creates a new generative function `sc`. We can then obtain the trace of `sc`:
149+
150+
```julia
151+
(trace, _) = simulate(sc, (2, 5.0, 3.0))
152+
```
153+
154+
The resulting trace contains the subtrace from the branch with index `2` - in this case, a call to `fuzz`:
155+
156+
```
157+
158+
└── :z : 13.552870875213735
159+
```

src/modeling_library/cond.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# ------------ Switch trace ------------ #
2+
3+
struct SwitchTrace{T} <: Trace
4+
gen_fn::GenerativeFunction{T}
5+
index::Int
6+
branch::Trace
7+
retval::T
8+
args::Tuple
9+
score::Float64
10+
noise::Float64
11+
end
12+
13+
@inline get_choices(tr::SwitchTrace) = get_choices(tr.branch)
14+
@inline get_retval(tr::SwitchTrace) = tr.retval
15+
@inline get_args(tr::SwitchTrace) = tr.args
16+
@inline get_score(tr::SwitchTrace) = tr.score
17+
@inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn
18+
@inline Base.getindex(tr::SwitchTrace, addr) = Base.getindex(tr.branch, addr)
19+
@inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection)
20+
@inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise

src/modeling_library/modeling_library.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,16 @@ include("dist_dsl/dist_dsl.jl")
6666
# code shared by vector-shaped combinators
6767
include("vector.jl")
6868

69+
# traces for with prob/switch combinator
70+
include("cond.jl")
71+
6972
# built-in generative function combinators
7073
include("choice_at/choice_at.jl")
7174
include("call_at/call_at.jl")
7275
include("map/map.jl")
7376
include("unfold/unfold.jl")
7477
include("recurse/recurse.jl")
78+
include("switch/switch.jl")
7579

7680
#############################################################
7781
# abstractions for constructing custom generative functions #
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
mutable struct SwitchAssessState{T}
2+
weight::Float64
3+
retval::T
4+
SwitchAssessState{T}(weight::Float64) where T = new{T}(weight)
5+
end
6+
7+
function process!(gen_fn::Switch{C, N, K, T},
8+
index::Int,
9+
args::Tuple,
10+
choices::ChoiceMap,
11+
state::SwitchAssessState{T}) where {C, N, K, T}
12+
(weight, retval) = assess(getindex(gen_fn.branches, index), args, choices)
13+
state.weight = weight
14+
state.retval = retval
15+
end
16+
17+
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)
18+
19+
function assess(gen_fn::Switch{C, N, K, T},
20+
args::Tuple,
21+
choices::ChoiceMap) where {C, N, K, T}
22+
index = args[1]
23+
state = SwitchAssessState{T}(0.0)
24+
process!(gen_fn, index, args[2 : end], choices, state)
25+
return state.weight, state.retval
26+
end
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
@inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = choice_gradients(getfield(trace, :branch), selection, retval_grad)
2+
@inline accumulate_param_gradients!(trace::SwitchTrace{T}, retval_grad, scale_factor = 1.) where {T} = accumulate_param_gradients!(getfield(trace, :branch), retval_grad, scale_factor)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
mutable struct SwitchGenerateState{T}
2+
score::Float64
3+
noise::Float64
4+
weight::Float64
5+
index::Int
6+
subtrace::Trace
7+
retval::T
8+
SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight)
9+
end
10+
11+
function process!(gen_fn::Switch{C, N, K, T},
12+
index::Int,
13+
args::Tuple,
14+
choices::ChoiceMap,
15+
state::SwitchGenerateState{T}) where {C, N, K, T}
16+
17+
(subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices)
18+
state.index = index
19+
state.subtrace = subtrace
20+
state.weight += weight
21+
state.retval = get_retval(subtrace)
22+
end
23+
24+
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)
25+
26+
function generate(gen_fn::Switch{C, N, K, T},
27+
args::Tuple,
28+
choices::ChoiceMap) where {C, N, K, T}
29+
30+
index = args[1]
31+
state = SwitchGenerateState{T}(0.0, 0.0, 0.0)
32+
process!(gen_fn, index, args[2 : end], choices, state)
33+
return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight
34+
end
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
mutable struct SwitchProposeState{T}
2+
choices::DynamicChoiceMap
3+
weight::Float64
4+
retval::T
5+
SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight)
6+
end
7+
8+
function process!(gen_fn::Switch{C, N, K, T},
9+
index::Int,
10+
args::Tuple,
11+
state::SwitchProposeState{T}) where {C, N, K, T}
12+
13+
(submap, weight, retval) = propose(getindex(gen_fn.branches, index), args)
14+
state.choices = submap
15+
state.weight += weight
16+
state.retval = retval
17+
end
18+
19+
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)
20+
21+
function propose(gen_fn::Switch{C, N, K, T},
22+
args::Tuple) where {C, N, K, T}
23+
24+
index = args[1]
25+
choices = choicemap()
26+
state = SwitchProposeState{T}(choices, 0.0)
27+
process!(gen_fn, index, args[2:end], state)
28+
return state.choices, state.weight, state.retval
29+
end
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
mutable struct SwitchRegenerateState{T}
2+
weight::Float64
3+
score::Float64
4+
noise::Float64
5+
prev_trace::Trace
6+
trace::Trace
7+
index::Int
8+
retdiff::Diff
9+
SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
10+
end
11+
12+
function process!(gen_fn::Switch{C, N, K, T},
13+
index::Int,
14+
index_argdiff::Diff,
15+
args::Tuple,
16+
kernel_argdiffs::Tuple,
17+
selection::Selection,
18+
state::SwitchRegenerateState{T}) where {C, N, K, T}
19+
branch_fn = getfield(gen_fn.branches, index)
20+
merged = get_selected(get_choices(state.prev_trace), complement(selection))
21+
new_trace, weight = generate(branch_fn, args, merged)
22+
retdiff = UnknownChange()
23+
weight -= project(state.prev_trace, complement(selection))
24+
weight += (project(new_trace, selection) - project(state.prev_trace, selection))
25+
state.index = index
26+
state.weight = weight
27+
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
28+
state.score = get_score(new_trace)
29+
state.trace = new_trace
30+
state.retdiff = retdiff
31+
end
32+
33+
function process!(gen_fn::Switch{C, N, K, T},
34+
index::Int,
35+
index_argdiff::NoChange,
36+
args::Tuple,
37+
kernel_argdiffs::Tuple,
38+
selection::Selection,
39+
state::SwitchRegenerateState{T}) where {C, N, K, T}
40+
new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection)
41+
state.index = index
42+
state.weight = weight
43+
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
44+
state.score = get_score(new_trace)
45+
state.trace = new_trace
46+
state.retdiff = retdiff
47+
end
48+
49+
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state)
50+
51+
function regenerate(trace::SwitchTrace{T},
52+
args::Tuple,
53+
argdiffs::Tuple,
54+
selection::Selection) where T
55+
gen_fn = trace.gen_fn
56+
index, index_argdiff = args[1], argdiffs[1]
57+
state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace)
58+
process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state)
59+
return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff
60+
end
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
mutable struct SwitchSimulateState{T}
2+
score::Float64
3+
noise::Float64
4+
index::Int
5+
subtrace::Trace
6+
retval::T
7+
SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise)
8+
end
9+
10+
function process!(gen_fn::Switch{C, N, K, T},
11+
index::Int,
12+
args::Tuple,
13+
state::SwitchSimulateState{T}) where {C, N, K, T}
14+
local retval::T
15+
subtrace = simulate(getindex(gen_fn.branches, index), args)
16+
state.index = index
17+
state.noise += project(subtrace, EmptySelection())
18+
state.subtrace = subtrace
19+
state.score += get_score(subtrace)
20+
state.retval = get_retval(subtrace)
21+
end
22+
23+
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)
24+
25+
function simulate(gen_fn::Switch{C, N, K, T},
26+
args::Tuple) where {C, N, K, T}
27+
28+
index = args[1]
29+
state = SwitchSimulateState{T}(0.0, 0.0)
30+
process!(gen_fn, index, args[2 : end], state)
31+
return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise)
32+
end
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace}
2+
branches::NTuple{N, GenerativeFunction{T}}
3+
cases::Dict{C, Int}
4+
function Switch(gen_fns::GenerativeFunction...)
5+
@assert !isempty(gen_fns)
6+
rettype = get_return_type(getindex(gen_fns, 1))
7+
new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}())
8+
end
9+
function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C
10+
@assert !isempty(gen_fns)
11+
rettype = get_return_type(getindex(gen_fns, 1))
12+
new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d)
13+
end
14+
end
15+
export Switch
16+
17+
has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.branches)...)) do as
18+
all(as)
19+
end
20+
accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches)
21+
22+
function (gen_fn::Switch)(index::Int, args...)
23+
(_, _, retval) = propose(gen_fn, (index, args...))
24+
retval
25+
end
26+
27+
function (gen_fn::Switch{C})(index::C, args...) where C
28+
(_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...))
29+
retval
30+
end
31+
32+
include("assess.jl")
33+
include("propose.jl")
34+
include("simulate.jl")
35+
include("generate.jl")
36+
include("update.jl")
37+
include("regenerate.jl")
38+
include("backprop.jl")
39+
40+
@doc(
41+
"""
42+
gen_fn = Switch(gen_fns::GenerativeFunction...)
43+
44+
Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` where the first index indicates which branch to call.
45+
46+
gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T
47+
48+
Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` or an argument tuple of type `Tuple{T, ...}` where the first index either indicates which branch to call, or indicates an index into `d` which maps to the selected branch. This form is meant for convenience - it allows the programmer to use `d` like if-else or case statements.
49+
50+
`Switch` is designed to allow for the expression of patterns of if-else control flow. `gen_fns` must satisfy a few requirements:
51+
52+
1. Each `gen_fn` in `gen_fns` must accept the same argument types.
53+
2. Each `gen_fn` in `gen_fns` must return the same return type.
54+
55+
Otherwise, each `gen_fn` can come from different modeling languages, possess different traces, etc.
56+
""", Switch)

0 commit comments

Comments
 (0)