@@ -16,53 +16,79 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
1616 choice_submap_iterator = get_submaps_shallow (choices)
1717 choice_value_iterator = get_values_shallow (choices)
1818 new_choices = DynamicChoiceMap ()
19- for (key, value) in prev_choice_value_iterator
20- key in keys (choice_value_iterator) && continue
21- set_value! (new_choices, key, value)
19+
20+ # Add (address, value) to new_choices from prev_choices if address does not occur in choices.
21+ for (address, value) in prev_choice_value_iterator
22+ address in keys (choice_value_iterator) && continue
23+ set_value! (new_choices, address, value)
2224 end
23- for (key, node1) in prev_choice_submap_iterator
24- if key in keys (choice_submap_iterator)
25- node2 = get_submap (choices, key)
25+
26+ # Add (address, submap) to new_choices from prev_choices if address does not occur in choices.
27+ # If it does, enter a recursive call to update_recurse_merge.
28+ for (address, node1) in prev_choice_submap_iterator
29+ if address in keys (choice_submap_iterator)
30+ node2 = get_submap (choices, address)
2631 node = update_recurse_merge (node1, node2)
27- set_submap! (new_choices, key , node)
32+ set_submap! (new_choices, address , node)
2833 else
29- set_submap! (new_choices, key , node1)
34+ set_submap! (new_choices, address , node1)
3035 end
3136 end
32- for (key, value) in choice_value_iterator
33- set_value! (new_choices, key, value)
37+
38+ # Add (address, value) from choices to new_choices. This is okay because we've excluded any conflicting addresses from the prev_choices above.
39+ for (address, value) in choice_value_iterator
40+ set_value! (new_choices, address, value)
3441 end
42+
3543 sel, _ = zip (prev_choice_submap_iterator... )
3644 comp = complement (select (sel... ))
37- for (key , node) in get_submaps_shallow (get_selected (choices, comp))
38- set_submap! (new_choices, key , node)
45+ for (address , node) in get_submaps_shallow (get_selected (choices, comp))
46+ set_submap! (new_choices, address , node)
3947 end
4048 return new_choices
4149end
4250
43- function update_discard (prev_trace:: Trace , choices:: ChoiceMap , new_trace:: Trace )
51+ @doc (
52+ """
53+ update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
54+
55+ Returns choices that are in constraints, merged with all choices in the previous trace that do not have the same address as some choice in the constraints."
56+ """ , update_recurse_merge)
57+
58+ function update_discard (prev_choices:: ChoiceMap , choices:: ChoiceMap , new_choices:: ChoiceMap )
4459 discard = choicemap ()
45- prev_choices = get_choices (prev_trace)
4660 for (k, v) in get_submaps_shallow (prev_choices)
47- isempty (get_submap (get_choices (new_trace), k)) && continue
48- isempty (get_submap (choices, k)) && continue
49- set_submap! (discard, k, v)
61+ new_submap = get_submap (new_choices, k)
62+ choices_submap = get_submap (choices, k)
63+ sub_discard = update_discard (v, choices_submap, new_submap)
64+ set_submap! (discard, k, sub_discard)
5065 end
5166 for (k, v) in get_values_shallow (prev_choices)
52- has_value (get_choices (new_trace) , k) || continue
53- has_value (choices , k) || continue
54- set_value! (discard, k, v)
67+ if ( ! has_value (new_choices , k) || has_value (choices, k))
68+ set_value! (discard , k, v)
69+ end
5570 end
5671 discard
5772end
5873
74+ @doc (
75+ """
76+ update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap)
77+
78+ Returns choices from previous trace that:
79+ 1. have an address which does not appear in the new trace.
80+ 2. have an address which does appear in the constraints.
81+ """ , update_discard)
82+
83+ @inline update_discard (prev_trace:: Trace , choices:: ChoiceMap , new_trace:: Trace ) = update_discard (get_choices (prev_trace), choices, get_choices (new_trace))
84+
5985function process! (gen_fn:: Switch{C, N, K, T} ,
60- index:: Int ,
61- index_argdiff:: UnknownChange , # TODO : Diffed wrapper?
62- args:: Tuple ,
63- kernel_argdiffs:: Tuple ,
64- choices:: ChoiceMap ,
65- state:: SwitchUpdateState{T} ) where {C, N, K, T, DV}
86+ index:: Int ,
87+ index_argdiff:: UnknownChange ,
88+ args:: Tuple ,
89+ kernel_argdiffs:: Tuple ,
90+ choices:: ChoiceMap ,
91+ state:: SwitchUpdateState{T} ) where {C, N, K, T, DV}
6692
6793 # Generate new trace.
6894 merged = update_recurse_merge (get_choices (state. prev_trace), choices)
@@ -81,12 +107,12 @@ function process!(gen_fn::Switch{C, N, K, T},
81107end
82108
83109function process! (gen_fn:: Switch{C, N, K, T} ,
84- index:: Int ,
85- index_argdiff:: NoChange , # TODO : Diffed wrapper?
86- args:: Tuple ,
87- kernel_argdiffs:: Tuple ,
88- choices:: ChoiceMap ,
89- state:: SwitchUpdateState{T} ) where {C, N, K, T}
110+ index:: Int ,
111+ index_argdiff:: NoChange , # TODO : Diffed wrapper?
112+ args:: Tuple ,
113+ kernel_argdiffs:: Tuple ,
114+ choices:: ChoiceMap ,
115+ state:: SwitchUpdateState{T} ) where {C, N, K, T}
90116
91117 # Update trace.
92118 new_trace, weight, retdiff, discard = update (getfield (state. prev_trace, :branch ), args, kernel_argdiffs, choices)
104130@inline process! (gen_fn:: Switch{C, N, K, T} , index:: C , index_argdiff:: Diff , args:: Tuple , kernel_argdiffs:: Tuple , choices:: ChoiceMap , state:: SwitchUpdateState{T} ) where {C, N, K, T} = process! (gen_fn, getindex (gen_fn. cases, index), index_argdiff, args, kernel_argdiffs, choices, state)
105131
106132function update (trace:: SwitchTrace{T} ,
107- args:: Tuple ,
108- argdiffs:: Tuple ,
109- choices:: ChoiceMap ) where T
133+ args:: Tuple ,
134+ argdiffs:: Tuple ,
135+ choices:: ChoiceMap ) where T
110136 gen_fn = trace. gen_fn
111137 index, index_argdiff = args[1 ], argdiffs[1 ]
112138 state = SwitchUpdateState {T} (0.0 , 0.0 , 0.0 , trace)
0 commit comments