1+ """
2+ (traces, log_norm_weights, lml_est) = enumerative_inference(
3+ model::GenerativeFunction, model_args::Tuple,
4+ observations::ChoiceMap, choice_vol_iter
5+ )
6+
7+ Run enumerative inference over a `model`, given `observations` and an iterator over
8+ choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
9+ to be iterated over. Return an array of traces and associated log-weights with the
10+ same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
11+ corresponds to the log probablity of the volume of sample space that the trace
12+ represents. Also return an estimate of the log marginal likelihood of the
13+ observations (`lml_est`).
14+
15+ All addresses in the `observations` choice map must be sampled by the model when
16+ given the model arguments. The same constraint applies to choice maps enumerated
17+ over by `choice_vol_iter`, which must also avoid sharing addresses with the
18+ `observations`. An iterator over a grid of choice maps and volumes can be constructed
19+ with [`choice_vol_grid`](@ref).
20+ """
21+ function enumerative_inference (
22+ model:: GenerativeFunction{T,U} , model_args:: Tuple ,
23+ observations:: ChoiceMap , choice_vol_iter:: I
24+ ) where {T,U,I}
25+ if Base. IteratorSize (I) isa Base. HasShape
26+ traces = Array {U} (undef, size (choice_vol_iter))
27+ log_weights = Array {Float64} (undef, size (choice_vol_iter))
28+ elseif Base. IteratorSize (I) isa Base. HasLength
29+ traces = Vector {U} (undef, length (choice_vol_iter))
30+ log_weights = Vector {Float64} (undef, length (choice_vol_iter))
31+ else
32+ choice_vol_iter = collect (choice_vol_iter)
33+ traces = Vector {U} (undef, length (choice_vol_iter))
34+ log_weights = Vector {Float64} (undef, length (choice_vol_iter))
35+ end
36+ for (i, (choices, vol)) in enumerate (choice_vol_iter)
37+ constraints = merge (observations, choices)
38+ (traces[i], log_weight) = generate (model, model_args, constraints)
39+ log_weights[i] = log_weight + log (vol)
40+ end
41+ log_total_weight = logsumexp (log_weights)
42+ log_normalized_weights = log_weights .- log_total_weight
43+ return (traces, log_normalized_weights, log_total_weight)
44+ end
45+
46+ """
47+ choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
48+
49+ Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
50+ over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration.
51+
52+ Each `addr` is an address of a random choice, and `vals` are the corresponding
53+ values or intervals to enumerate over. The (optional) `support` denotes whether
54+ each random choice is `:discrete` (default) or `:continuous`. This controls how
55+ the grid is constructed:
56+ - `support = :discrete`: The grid iterates over each value in `vals`.
57+ - `support = :continuous` and `dims == Val(1)`: The grid iterates over the
58+ anchors of 1D intervals whose endpoints are given by `vals`.
59+ - `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
60+ over the anchors of multi-dimensional regions defined `vals`, which is a tuple
61+ of interval endpoints for each dimension.
62+ Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
63+ The `anchor` keyword argument controls which point in each interval is used as
64+ the anchor (`:left`, `:right`, or `:midpoint`).
65+
66+ The volume `vol` associated with each set of `choices` in the grid is given by
67+ the product of the volumes of each continuous region used to construct those
68+ choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`.
69+ """
70+ function choice_vol_grid (grid_specs:: Tuple... ; anchor:: Symbol = :midpoint )
71+ val_iter = (expand_grid_spec_to_values (spec... ; anchor= anchor)
72+ for spec in grid_specs)
73+ val_iter = Iterators. product (val_iter... )
74+ vol_iter = (expand_grid_spec_to_volumes (spec... ) for spec in grid_specs)
75+ vol_iter = Iterators. product (vol_iter... )
76+ choice_vol_iter = Iterators. map (zip (val_iter, vol_iter)) do (vals, vols)
77+ return (choicemap (vals... ), prod (vols))
78+ end
79+ return choice_vol_iter
80+ end
81+
82+ function expand_grid_spec_to_values (
83+ addr, vals, support:: Symbol = :discrete , dims:: Val{N} = Val (1 );
84+ anchor:: Symbol = :midpoint
85+ ) where {N}
86+ if support == :discrete
87+ return ((addr, v) for v in vals)
88+ elseif support == :continuous && N == 1
89+ if anchor == :left
90+ vals = @view (vals[begin : end - 1 ])
91+ elseif anchor == :right
92+ vals = @view (vals[begin + 1 : end ])
93+ else
94+ vals = @view (vals[begin : end - 1 ]) .+ (diff (vals) ./ 2 )
95+ end
96+ return ((addr, v) for v in vals)
97+ elseif support == :continuous && N > 1
98+ @assert length (vals) == N " Dimension mismatch between `vals` and `dims`"
99+ vals = map (vals) do vs
100+ if anchor == :left
101+ vs = @view (vs[begin : end - 1 ])
102+ elseif anchor == :right
103+ vs = @view (vs[begin + 1 : end ])
104+ else
105+ vs = @view (vs[begin : end - 1 ]) .+ (diff (vs) ./ 2 )
106+ end
107+ return vs
108+ end
109+ return ((addr, v) for v in Iterators. product (vals... ))
110+ else
111+ error (" Support must be :discrete or :continuous" )
112+ end
113+ end
114+
115+ function expand_grid_spec_to_volumes (
116+ addr, vals, support:: Symbol = :discrete , dims:: Val{N} = Val (1 )
117+ ) where {N}
118+ if support == :discrete
119+ return ones (length (vals))
120+ elseif support == :continuous && N == 1
121+ return diff (vals)
122+ elseif support == :continuous && N > 1
123+ @assert length (vals) == N " Dimension mismatch between `vals` and `dims`"
124+ diffs = Iterators. product ((diff (vs) for vs in vals). .. )
125+ return (prod (ds) for ds in diffs)
126+ else
127+ error (" Support must be :discrete or :continuous" )
128+ end
129+ end
130+
131+ export enumerative_inference, choice_vol_grid
0 commit comments