55 )
66
77Run enumerative inference over a `model`, given `observations` and an iterator over
8- choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
9- to be iterated over. Return an array of traces and associated log-weights with the
10- same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
11- corresponds to the log probablity of the volume of sample space that the trace
12- represents. Also return an estimate of the log marginal likelihood of the
13- observations (`lml_est`).
8+ choice maps and their associated log-volumes (`choice_vol_iter`), specifying the
9+ choices to be iterated over. An iterator over a grid of choice maps and log-volumes
10+ can be constructed with [`choice_vol_grid`](@ref).
11+
12+ Return an array of traces and associated log-weights with the same shape as
13+ `choice_vol_iter`. The log-weight of each trace is normalized, and corresponds
14+ to the log probability of the volume of sample space that the trace represents.
15+ Also return an estimate of the log marginal likelihood of the observations (`lml_est`).
1416
1517All addresses in the `observations` choice map must be sampled by the model when
1618given the model arguments. The same constraint applies to choice maps enumerated
1719over by `choice_vol_iter`, which must also avoid sharing addresses with the
18- `observations`. An iterator over a grid of choice maps and volumes can be constructed
19- with [`choice_vol_grid`](@ref).
20+ `observations`.
2021"""
2122function enumerative_inference (
2223 model:: GenerativeFunction{T,U} , model_args:: Tuple ,
@@ -33,10 +34,10 @@ function enumerative_inference(
3334 traces = Vector {U} (undef, length (choice_vol_iter))
3435 log_weights = Vector {Float64} (undef, length (choice_vol_iter))
3536 end
36- for (i, (choices, vol )) in enumerate (choice_vol_iter)
37+ for (i, (choices, log_vol )) in enumerate (choice_vol_iter)
3738 constraints = merge (observations, choices)
3839 (traces[i], log_weight) = generate (model, model_args, constraints)
39- log_weights[i] = log_weight + log (vol)
40+ log_weights[i] = log_weight + log_vol
4041 end
4142 log_total_weight = logsumexp (log_weights)
4243 log_normalized_weights = log_weights .- log_total_weight
4748 choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
4849
4950Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
50- over tuples of the form `(choices::ChoiceMap, vol ::Real)` via grid enumeration.
51+ over tuples of the form `(choices::ChoiceMap, log_vol ::Real)` via grid enumeration.
5152
5253Each `addr` is an address of a random choice, and `vals` are the corresponding
5354values or intervals to enumerate over. The (optional) `support` denotes whether
@@ -63,9 +64,9 @@ Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
6364The `anchor` keyword argument controls which point in each interval is used as
6465the anchor (`:left`, `:right`, or `:midpoint`).
6566
66- The volume `vol ` associated with each set of `choices` in the grid is given by
67- the product of the volumes of each continuous region used to construct those
68- choices. If all addresses enumerated over are `:discrete`, then `vol = 1 .0`.
67+ The log- volume `log_vol ` associated with each set of `choices` in the grid is given
68+ by the log- product of the volumes of each continuous region used to construct those
69+ choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0 .0`.
6970"""
7071function choice_vol_grid (grid_specs:: Tuple... ; anchor:: Symbol = :midpoint )
7172 val_iter = (expand_grid_spec_to_values (spec... ; anchor= anchor)
@@ -74,7 +75,7 @@ function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
7475 vol_iter = (expand_grid_spec_to_volumes (spec... ) for spec in grid_specs)
7576 vol_iter = Iterators. product (vol_iter... )
7677 choice_vol_iter = Iterators. map (zip (val_iter, vol_iter)) do (vals, vols)
77- return (choicemap (vals... ), prod (vols))
78+ return (choicemap (vals... ), sum (vols))
7879 end
7980 return choice_vol_iter
8081end
@@ -116,13 +117,13 @@ function expand_grid_spec_to_volumes(
116117 addr, vals, support:: Symbol = :discrete , dims:: Val{N} = Val (1 )
117118) where {N}
118119 if support == :discrete
119- return ones (length (vals))
120+ return zeros (length (vals))
120121 elseif support == :continuous && N == 1
121- return diff (vals)
122+ return log .( diff (vals) )
122123 elseif support == :continuous && N > 1
123124 @assert length (vals) == N " Dimension mismatch between `vals` and `dims`"
124- diffs = Iterators. product ((diff (vs) for vs in vals). .. )
125- return (prod (ds) for ds in diffs)
125+ diffs = Iterators. product ((log .( diff (vs) ) for vs in vals). .. )
126+ return (sum (ds) for ds in diffs)
126127 else
127128 error (" Support must be :discrete or :continuous" )
128129 end
0 commit comments