Skip to content

Commit 0465965

Browse files
committed
Renamed mix field of Switch generative function to branches to more accurately reflect the pattern.
1 parent 176b9e9 commit 0465965

File tree

7 files changed

+9
-9
lines changed

7 files changed

+9
-9
lines changed

src/modeling_library/switch/assess.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function process!(gen_fn::Switch{C, N, K, T},
99
args::Tuple,
1010
choices::ChoiceMap,
1111
state::SwitchAssessState{T}) where {C, N, K, T}
12-
(weight, retval) = assess(getindex(gen_fn.mix, index), args, choices)
12+
(weight, retval) = assess(getindex(gen_fn.branches, index), args, choices)
1313
state.weight = weight
1414
state.retval = retval
1515
end

src/modeling_library/switch/generate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function process!(gen_fn::Switch{C, N, K, T},
1414
choices::ChoiceMap,
1515
state::SwitchGenerateState{T}) where {C, N, K, T}
1616

17-
(subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices)
17+
(subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices)
1818
state.index = index
1919
state.subtrace = subtrace
2020
state.weight += weight

src/modeling_library/switch/propose.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function process!(gen_fn::Switch{C, N, K, T},
1010
args::Tuple,
1111
state::SwitchProposeState{T}) where {C, N, K, T}
1212

13-
(submap, weight, retval) = propose(getindex(gen_fn.mix, index), args)
13+
(submap, weight, retval) = propose(getindex(gen_fn.branches, index), args)
1414
state.choices = submap
1515
state.weight += weight
1616
state.retval = retval

src/modeling_library/switch/regenerate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function process!(gen_fn::Switch{C, N, K, T},
3838
kernel_argdiffs::Tuple,
3939
selection::Selection,
4040
state::SwitchRegenerateState{T}) where {C, N, K, T}
41-
branch_fn = getfield(gen_fn.mix, index)
41+
branch_fn = getfield(gen_fn.branches, index)
4242
merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection)
4343
new_trace, weight = generate(branch_fn, args, merged)
4444
retdiff = UnknownChange()

src/modeling_library/switch/simulate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function process!(gen_fn::Switch{C, N, K, T},
1212
args::Tuple,
1313
state::SwitchSimulateState{T}) where {C, N, K, T}
1414
local retval::T
15-
subtrace = simulate(getindex(gen_fn.mix, index), args)
15+
subtrace = simulate(getindex(gen_fn.branches, index), args)
1616
state.index = index
1717
state.noise += project(subtrace, EmptySelection())
1818
state.subtrace = subtrace

src/modeling_library/switch/switch.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace}
2-
mix::NTuple{N, GenerativeFunction{T}}
2+
branches::NTuple{N, GenerativeFunction{T}}
33
cases::Dict{C, Int}
44
function Switch(gen_fns::GenerativeFunction...)
55
@assert !isempty(gen_fns)
@@ -14,10 +14,10 @@ struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace}
1414
end
1515
export Switch
1616

17-
has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.mix)...)) do as
17+
has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.branches)...)) do as
1818
all(as)
1919
end
20-
accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.mix)
20+
accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches)
2121

2222
function (gen_fn::Switch)(index::Int, args...)
2323
(_, _, retval) = propose(gen_fn, (index, args...))

src/modeling_library/switch/update.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function process!(gen_fn::Switch{C, N, K, T},
6666

6767
# Generate new trace.
6868
merged = update_recurse_merge(get_choices(state.prev_trace), choices)
69-
branch_fn = getfield(gen_fn.mix, index)
69+
branch_fn = getfield(gen_fn.branches, index)
7070
new_trace, weight = generate(branch_fn, args, merged)
7171
weight -= get_score(state.prev_trace)
7272
state.discard = update_discard(state.prev_trace, choices, new_trace)

0 commit comments

Comments
 (0)