Skip to content

Commit 43c7274

Browse files
committed
Addressed review comments. Added docstrings where necessary. Corrected update_discard. Added test to test the discard functionality in a hierarchical model example.
1 parent 0465965 commit 43c7274

File tree

3 files changed

+101
-59
lines changed

3 files changed

+101
-59
lines changed

src/modeling_library/switch/regenerate.jl

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,15 @@ mutable struct SwitchRegenerateState{T}
99
SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
1010
end
1111

12-
@inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::EmptySelection) = prev_choices
13-
@inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::AllSelection) = choicemap()
14-
function regenerate_recurse_merge(prev_choices::ChoiceMap, selection::Selection)
15-
prev_choice_value_iterator = get_values_shallow(prev_choices)
16-
prev_choice_submap_iterator = get_submaps_shallow(prev_choices)
17-
new_choices = choicemap()
18-
for (key, value) in prev_choice_value_iterator
19-
in(key, selection) && continue
20-
set_value!(new_choices, key, value)
21-
end
22-
for (key, node1) in prev_choice_submap_iterator
23-
if in(key, selection)
24-
subsel = getindex(selection, key)
25-
node = regenerate_recurse_merge(node1, subsel)
26-
set_submap!(new_choices, key, node)
27-
else
28-
set_submap!(new_choices, key, node1)
29-
end
30-
end
31-
return new_choices
32-
end
33-
3412
function process!(gen_fn::Switch{C, N, K, T},
3513
index::Int,
36-
index_argdiff::UnknownChange,
14+
index_argdiff::Diff,
3715
args::Tuple,
3816
kernel_argdiffs::Tuple,
3917
selection::Selection,
4018
state::SwitchRegenerateState{T}) where {C, N, K, T}
4119
branch_fn = getfield(gen_fn.branches, index)
42-
merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection)
20+
merged = get_selected(get_choices(state.prev_trace), complement(selection))
4321
new_trace, weight = generate(branch_fn, args, merged)
4422
retdiff = UnknownChange()
4523
weight -= project(state.prev_trace, complement(selection))

src/modeling_library/switch/update.jl

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,79 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
1616
choice_submap_iterator = get_submaps_shallow(choices)
1717
choice_value_iterator = get_values_shallow(choices)
1818
new_choices = DynamicChoiceMap()
19-
for (key, value) in prev_choice_value_iterator
20-
key in keys(choice_value_iterator) && continue
21-
set_value!(new_choices, key, value)
19+
20+
# Add (address, value) to new_choices from prev_choices if address does not occur in choices.
21+
for (address, value) in prev_choice_value_iterator
22+
address in keys(choice_value_iterator) && continue
23+
set_value!(new_choices, address, value)
2224
end
23-
for (key, node1) in prev_choice_submap_iterator
24-
if key in keys(choice_submap_iterator)
25-
node2 = get_submap(choices, key)
25+
26+
# Add (address, submap) to new_choices from prev_choices if address does not occur in choices.
27+
# If it does, enter a recursive call to update_recurse_merge.
28+
for (address, node1) in prev_choice_submap_iterator
29+
if address in keys(choice_submap_iterator)
30+
node2 = get_submap(choices, address)
2631
node = update_recurse_merge(node1, node2)
27-
set_submap!(new_choices, key, node)
32+
set_submap!(new_choices, address, node)
2833
else
29-
set_submap!(new_choices, key, node1)
34+
set_submap!(new_choices, address, node1)
3035
end
3136
end
32-
for (key, value) in choice_value_iterator
33-
set_value!(new_choices, key, value)
37+
38+
# Add (address, value) from choices to new_choices. This is okay because we've excluded any conflicting addresses from the prev_choices above.
39+
for (address, value) in choice_value_iterator
40+
set_value!(new_choices, address, value)
3441
end
42+
3543
sel, _ = zip(prev_choice_submap_iterator...)
3644
comp = complement(select(sel...))
37-
for (key, node) in get_submaps_shallow(get_selected(choices, comp))
38-
set_submap!(new_choices, key, node)
45+
for (address, node) in get_submaps_shallow(get_selected(choices, comp))
46+
set_submap!(new_choices, address, node)
3947
end
4048
return new_choices
4149
end
4250

43-
function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace)
51+
@doc(
52+
"""
53+
update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
54+
55+
Returns choices that are in constraints, merged with all choices in the previous trace that do not have the same address as some choice in the constraints."
56+
""", update_recurse_merge)
57+
58+
function update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap)
4459
discard = choicemap()
45-
prev_choices = get_choices(prev_trace)
4660
for (k, v) in get_submaps_shallow(prev_choices)
47-
isempty(get_submap(get_choices(new_trace), k)) && continue
48-
isempty(get_submap(choices, k)) && continue
49-
set_submap!(discard, k, v)
61+
new_submap = get_submap(new_choices, k)
62+
choices_submap = get_submap(choices, k)
63+
sub_discard = update_discard(v, choices_submap, new_submap)
64+
set_submap!(discard, k, sub_discard)
5065
end
5166
for (k, v) in get_values_shallow(prev_choices)
52-
has_value(get_choices(new_trace), k) || continue
53-
has_value(choices, k) || continue
54-
set_value!(discard, k, v)
67+
if (!has_value(new_choices, k) || has_value(choices, k))
68+
set_value!(discard, k, v)
69+
end
5570
end
5671
discard
5772
end
5873

74+
@doc(
75+
"""
76+
update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap)
77+
78+
Returns choices from previous trace that:
79+
1. have an address which does not appear in the new trace.
80+
2. have an address which does appear in the constraints.
81+
""", update_discard)
82+
83+
@inline update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) = update_discard(get_choices(prev_trace), choices, get_choices(new_trace))
84+
5985
function process!(gen_fn::Switch{C, N, K, T},
60-
index::Int,
61-
index_argdiff::UnknownChange, # TODO: Diffed wrapper?
62-
args::Tuple,
63-
kernel_argdiffs::Tuple,
64-
choices::ChoiceMap,
65-
state::SwitchUpdateState{T}) where {C, N, K, T, DV}
86+
index::Int,
87+
index_argdiff::UnknownChange,
88+
args::Tuple,
89+
kernel_argdiffs::Tuple,
90+
choices::ChoiceMap,
91+
state::SwitchUpdateState{T}) where {C, N, K, T, DV}
6692

6793
# Generate new trace.
6894
merged = update_recurse_merge(get_choices(state.prev_trace), choices)
@@ -81,12 +107,12 @@ function process!(gen_fn::Switch{C, N, K, T},
81107
end
82108

83109
function process!(gen_fn::Switch{C, N, K, T},
84-
index::Int,
85-
index_argdiff::NoChange, # TODO: Diffed wrapper?
86-
args::Tuple,
87-
kernel_argdiffs::Tuple,
88-
choices::ChoiceMap,
89-
state::SwitchUpdateState{T}) where {C, N, K, T}
110+
index::Int,
111+
index_argdiff::NoChange, # TODO: Diffed wrapper?
112+
args::Tuple,
113+
kernel_argdiffs::Tuple,
114+
choices::ChoiceMap,
115+
state::SwitchUpdateState{T}) where {C, N, K, T}
90116

91117
# Update trace.
92118
new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices)
@@ -104,9 +130,9 @@ end
104130
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state)
105131

106132
function update(trace::SwitchTrace{T},
107-
args::Tuple,
108-
argdiffs::Tuple,
109-
choices::ChoiceMap) where T
133+
args::Tuple,
134+
argdiffs::Tuple,
135+
choices::ChoiceMap) where T
110136
gen_fn = trace.gen_fn
111137
index, index_argdiff = args[1], argdiffs[1]
112138
state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace)

test/modeling_library/switch.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,4 +299,42 @@
299299
@test isapprox(gradients[:x => :z], expected_choice_grad[1])
300300
end
301301
end
302+
303+
# ------------ (More complex) hierarchy to test discard ------------ #
304+
305+
# Model chunk.
306+
@gen (grad) function bang3((grad)(x::Float64), (grad)(y::Float64))
307+
std::Float64 = 3.0
308+
z = @trace(normal(x + y, std), :z)
309+
q = @trace(bang2(z, y), :q)
310+
return z
311+
end
312+
@gen (grad) function fuzz3((grad)(x::Float64), (grad)(y::Float64))
313+
std::Float64 = 3.0
314+
z = @trace(normal(x + 2 * y, std), :z)
315+
m = @trace(normal(x + 3 * y, std), :m)
316+
q = @trace(bang3(z, y), :q)
317+
return z
318+
end
319+
sc3 = Switch(bang3, fuzz3)
320+
@gen (grad) function bam3(s::Int)
321+
x ~ sc3(s, 5.0, 3.0)
322+
return x
323+
end
324+
# ----.
325+
326+
@testset "update" begin
327+
tr = simulate(bam3, (2, ))
328+
old_sc = get_score(tr)
329+
chm = choicemap((:x => :z, 5.0))
330+
future_discarded = tr[:x => :z]
331+
new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm)
332+
@test discard[:x => :z] == future_discarded
333+
@test isapprox(old_sc, get_score(new_tr) - w)
334+
chm = choicemap((:x => :z, 10.0))
335+
future_discarded = tr[:x => :q => :q => :z]
336+
new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm)
337+
@test discard[:x => :q => :q => :z] == future_discarded
338+
@test isapprox(old_sc, get_score(new_tr) - w)
339+
end
302340
end

0 commit comments

Comments
 (0)