Skip to content

Commit 7104004

Browse files
committed
Switch to log-volumes for numeric stability.
1 parent 1c46274 commit 7104004

File tree

2 files changed

+24
-23
lines changed

2 files changed

+24
-23
lines changed

src/inference/enumerative.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
)
66
77
Run enumerative inference over a `model`, given `observations` and an iterator over
8-
choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
9-
to be iterated over. Return an array of traces and associated log-weights with the
10-
same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
11-
corresponds to the log probablity of the volume of sample space that the trace
12-
represents. Also return an estimate of the log marginal likelihood of the
13-
observations (`lml_est`).
8+
choice maps and their associated log-volumes (`choice_vol_iter`), specifying the
9+
choices to be iterated over. An iterator over a grid of choice maps and log-volumes
10+
can be constructed with [`choice_vol_grid`](@ref).
11+
12+
Return an array of traces and associated log-weights with the same shape as
13+
`choice_vol_iter`. The log-weight of each trace is normalized, and corresponds
14+
to the log probability of the volume of sample space that the trace represents.
15+
Also return an estimate of the log marginal likelihood of the observations (`lml_est`).
1416
1517
All addresses in the `observations` choice map must be sampled by the model when
1618
given the model arguments. The same constraint applies to choice maps enumerated
1719
over by `choice_vol_iter`, which must also avoid sharing addresses with the
18-
`observations`. An iterator over a grid of choice maps and volumes can be constructed
19-
with [`choice_vol_grid`](@ref).
20+
`observations`.
2021
"""
2122
function enumerative_inference(
2223
model::GenerativeFunction{T,U}, model_args::Tuple,
@@ -33,10 +34,10 @@ function enumerative_inference(
3334
traces = Vector{U}(undef, length(choice_vol_iter))
3435
log_weights = Vector{Float64}(undef, length(choice_vol_iter))
3536
end
36-
for (i, (choices, vol)) in enumerate(choice_vol_iter)
37+
for (i, (choices, log_vol)) in enumerate(choice_vol_iter)
3738
constraints = merge(observations, choices)
3839
(traces[i], log_weight) = generate(model, model_args, constraints)
39-
log_weights[i] = log_weight + log(vol)
40+
log_weights[i] = log_weight + log_vol
4041
end
4142
log_total_weight = logsumexp(log_weights)
4243
log_normalized_weights = log_weights .- log_total_weight
@@ -47,7 +48,7 @@ end
4748
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
4849
4950
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
50-
over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration.
51+
over tuples of the form `(choices::ChoiceMap, log_vol::Real)` via grid enumeration.
5152
5253
Each `addr` is an address of a random choice, and `vals` are the corresponding
5354
values or intervals to enumerate over. The (optional) `support` denotes whether
@@ -63,9 +64,9 @@ Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
6364
The `anchor` keyword argument controls which point in each interval is used as
6465
the anchor (`:left`, `:right`, or `:midpoint`).
6566
66-
The volume `vol` associated with each set of `choices` in the grid is given by
67-
the product of the volumes of each continuous region used to construct those
68-
choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`.
67+
The log-volume `log_vol` associated with each set of `choices` in the grid is given
68+
by the log-product of the volumes of each continuous region used to construct those
69+
choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0.0`.
6970
"""
7071
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
7172
val_iter = (expand_grid_spec_to_values(spec...; anchor=anchor)
@@ -74,7 +75,7 @@ function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
7475
vol_iter = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs)
7576
vol_iter = Iterators.product(vol_iter...)
7677
choice_vol_iter = Iterators.map(zip(val_iter, vol_iter)) do (vals, vols)
77-
return (choicemap(vals...), prod(vols))
78+
return (choicemap(vals...), sum(vols))
7879
end
7980
return choice_vol_iter
8081
end
@@ -116,13 +117,13 @@ function expand_grid_spec_to_volumes(
116117
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
117118
) where {N}
118119
if support == :discrete
119-
return ones(length(vals))
120+
return zeros(length(vals))
120121
elseif support == :continuous && N == 1
121-
return diff(vals)
122+
return log.(diff(vals))
122123
elseif support == :continuous && N > 1
123124
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
124-
diffs = Iterators.product((diff(vs) for vs in vals)...)
125-
return (prod(ds) for ds in diffs)
125+
diffs = Iterators.product((log.(diff(vs)) for vs in vals)...)
126+
return (sum(ds) for ds in diffs)
126127
else
127128
error("Support must be :discrete or :continuous")
128129
end

test/inference/enumerative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@
3838
@test size(grid) == (2, 10, 10, 10)
3939
@test length(grid) == 2000
4040

41-
choices, vol = first(grid)
41+
choices, log_vol = first(grid)
4242
@test choices == choicemap(
4343
(:degree, 1),
4444
((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
4545
)
46-
@test vol 0.2 * 0.2 * 0.2
46+
@test log_vol log(0.2^3)
4747

4848
test_choices(n::Int, cs) =
4949
cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n)
5050

51-
@test all(test_choices(2, choices) for (choices, vol) in grid)
52-
@test all(vol (0.2 * 0.2 * 0.2) for (choices, vol) in grid)
51+
@test all(test_choices(2, choices) for (choices, _) in grid)
52+
@test all(log_vol log(0.2^3) for (_, log_vol) in grid)
5353

5454
# run enumerative inference over grid
5555
traces, log_norm_weights, lml_est =

0 commit comments

Comments
 (0)