Skip to content

Commit 91d798f

Browse files
authored
Merge pull request #545 from probcomp/enumerative_inference
Add enumerative inference to the inference library
2 parents a5fc8e3 + 358d3e4 commit 91d798f

File tree

5 files changed

+290
-0
lines changed

5 files changed

+290
-0
lines changed

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pages = [
3030
"Custom Generative Functions" => "ref/modeling/custom_gen_fns.md",
3131
],
3232
"Inference Library" => [
33+
"Enumerative Inference" => "ref/inference/enumerative.md",
3334
"Importance Sampling" => "ref/inference/importance.md",
3435
"Markov Chain Monte Carlo" => "ref/inference/mcmc.md",
3536
"Particle Filtering & SMC" => "ref/inference/pf.md",
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Enumerative Inference
2+
3+
Enumerative inference can be used to compute the exact posterior distribution for a generative model
4+
with a finite number of discrete random choices, to compute a grid approximation of a continuous
5+
posterior density, or to perform stratified sampling by enumerating over discrete random choices and sampling
6+
the continuous random choices. This functionality is provided by [`enumerative_inference`](@ref).
7+
8+
```@docs
9+
enumerative_inference
10+
```
11+
12+
To construct a rectangular grid of [choice maps](../core/choice_maps.md) and their associated log-volumes to iterate over, use the [`choice_vol_grid`](@ref) function.
13+
14+
```@docs
15+
choice_vol_grid
16+
```
17+
18+
When the space of possible choice maps is not rectangular (e.g. some addresses only exist depending on the values of other addresses), iterators over choice maps and log-volumes can also be manually constructed.

src/inference/enumerative.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
"""
    (traces, log_norm_weights, lml_est) = enumerative_inference(
        model::GenerativeFunction, model_args::Tuple,
        observations::ChoiceMap, choice_vol_iter
    )

Perform enumerative inference on `model` conditioned on `observations`, visiting
every `(choices, log_vol)` pair produced by `choice_vol_iter`. Each pair holds a
choice map for the latent choices being enumerated and the log-volume of the
region of sample space that the choice map stands for. A rectangular grid of
such pairs can be built with [`choice_vol_grid`](@ref).

# Returns
- `traces`: one trace per enumerated choice map, stored in an array with the
  same shape as `choice_vol_iter` (a vector when the iterator has no shape).
- `log_norm_weights`: the normalized log-weight of each trace, i.e. the log
  probability of the volume of sample space the trace represents.
- `lml_est`: an estimate of the log marginal likelihood of the observations.

# Requirements
Every address in `observations` must be sampled by the model under `model_args`.
The same holds for the enumerated choice maps, which additionally must not share
addresses with `observations`. Unobserved choices that the enumerated choice maps
leave unspecified are sampled from the model's internal proposal distribution.
"""
function enumerative_inference(
    model::GenerativeFunction{T,U}, model_args::Tuple,
    observations::ChoiceMap, choice_vol_iter::I
) where {T,U,I}
    # Work out the shape of the output arrays from the iterator's size trait.
    size_trait = Base.IteratorSize(I)
    if size_trait isa Base.HasShape
        out_dims = size(choice_vol_iter)
    else
        # Iterators of unknown size are materialized so they can be measured.
        if !(size_trait isa Base.HasLength)
            choice_vol_iter = collect(choice_vol_iter)
        end
        out_dims = (length(choice_vol_iter),)
    end
    traces = Array{U}(undef, out_dims)
    log_weights = Array{Float64}(undef, out_dims)

    # Score each enumerated assignment jointly with the observations; the
    # counter from `enumerate` is used as a linear index into the outputs.
    for (idx, (choices, log_vol)) in enumerate(choice_vol_iter)
        trace, log_weight = generate(model, model_args, merge(observations, choices))
        traces[idx] = trace
        log_weights[idx] = log_weight + log_vol
    end

    # The log of the summed unnormalized weights is the marginal likelihood estimate.
    log_total_weight = logsumexp(log_weights)
    log_normalized_weights = log_weights .- log_total_weight
    return (traces, log_normalized_weights, log_total_weight)
end
48+
"""
    choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)

Build a grid iterator that yields `(choices::ChoiceMap, log_vol::Real)` tuples
from one or more specifications of the form `(addr, vals, [support, dims])`.

# Arguments
Each specification describes one random choice:
- `addr`: the address of the random choice.
- `vals`: the values, or interval endpoints, to enumerate over.
- `support` (optional): `:discrete` (default) or `:continuous`.
- `dims` (optional): `Val(N)`, the dimensionality of a continuous choice.
  Continuous choices are assumed to have `dims = Val(1)` by default.

How the grid is built for each address depends on `support` and `dims`:
- `support = :discrete`: the grid visits each value in `vals`.
- `support = :continuous` and `dims == Val(1)`: the grid visits the anchors of
  the 1D intervals whose endpoints are given by `vals`.
- `support = :continuous` and `dims == Val(N)` with `N` > 1: the grid visits the
  anchors of multi-dimensional regions defined by `vals`, which is a tuple of
  interval endpoints for each dimension.

# Keywords
- `anchor`: which point of each interval serves as its anchor (`:left`,
  `:right`, or `:midpoint`).

# Returns
An iterator over `(choices, log_vol)` tuples, where `log_vol` is the log-product
of the volumes of each continuous region used to construct `choices`. If every
enumerated address is `:discrete`, then `log_vol = 0.0`.
"""
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
    # Expand every specification into its (address, value) pairs and log-volumes.
    per_addr_vals = map(grid_specs) do spec
        expand_grid_spec_to_values(spec...; anchor=anchor)
    end
    per_addr_vols = map(spec -> expand_grid_spec_to_volumes(spec...), grid_specs)
    # Cartesian products over all addresses; both products share one shape.
    val_grid = Iterators.product(per_addr_vals...)
    vol_grid = Iterators.product(per_addr_vols...)
    # Log-volumes add across addresses because the volumes multiply.
    return ((choicemap(vals...), sum(vols)) for (vals, vols) in zip(val_grid, vol_grid))
end
84+
85+
function expand_grid_spec_to_values(
86+
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
87+
anchor::Symbol = :midpoint
88+
) where {N}
89+
if support == :discrete
90+
return ((addr, v) for v in vals)
91+
elseif support == :continuous && N == 1
92+
if anchor == :left
93+
vals = @view(vals[begin:end-1])
94+
elseif anchor == :right
95+
vals = @view(vals[begin+1:end])
96+
else
97+
vals = @view(vals[begin:end-1]) .+ (diff(vals) ./ 2)
98+
end
99+
return ((addr, v) for v in vals)
100+
elseif support == :continuous && N > 1
101+
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
102+
vals = map(vals) do vs
103+
if anchor == :left
104+
vs = @view(vs[begin:end-1])
105+
elseif anchor == :right
106+
vs = @view(vs[begin+1:end])
107+
else
108+
vs = @view(vs[begin:end-1]) .+ (diff(vs) ./ 2)
109+
end
110+
return vs
111+
end
112+
return ((addr, collect(v)) for v in Iterators.product(vals...))
113+
else
114+
error("Support must be :discrete or :continuous")
115+
end
116+
end
117+
118+
function expand_grid_spec_to_volumes(
119+
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
120+
) where {N}
121+
if support == :discrete
122+
return zeros(length(vals))
123+
elseif support == :continuous && N == 1
124+
return log.(diff(vals))
125+
elseif support == :continuous && N > 1
126+
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
127+
diffs = Iterators.product((log.(diff(vs)) for vs in vals)...)
128+
return (sum(ds) for ds in diffs)
129+
else
130+
error("Support must be :discrete or :continuous")
131+
end
132+
end
133+
134+
export enumerative_inference, choice_vol_grid

src/inference/inference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include("hmc.jl")
2121
include("mala.jl")
2222
include("elliptical_slice.jl")
2323

24+
include("enumerative.jl")
2425
include("importance.jl")
2526
include("particle_filter.jl")
2627
include("map_optimize.jl")

test/inference/enumerative.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
# Tests for `choice_vol_grid` and `enumerative_inference`: grid construction,
# weight normalization, and recovery of the posterior mode, on a polynomial
# regression model and on a 2D mixture of normals.
@testset "enumerative inference" begin

    # polynomial regression model
    @gen function poly_model(n::Int, xs)
        degree ~ uniform_discrete(1, n)
        coeffs = zeros(n+1)
        for d in 0:n
            coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
        end
        ys = zeros(length(xs))
        for (i, x) in enumerate(xs)
            x_powers = x .^ (0:n)
            y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
            ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
        end
        return ys
    end

    # synthetic dataset
    coeffs = [0.5, 0.1, -0.5]
    xs = collect(0.5:0.5:3.0)
    ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]

    observations = choicemap()
    for (i, y) in enumerate(ys)
        observations[(:y, i)] = y
    end

    # test construction of choicemap-volume grid
    grid = choice_vol_grid(
        (:degree, 1:2),
        ((:coeff, 0), -1:0.2:1, :continuous),
        ((:coeff, 1), -1:0.2:1, :continuous),
        ((:coeff, 2), -1:0.2:1, :continuous),
        anchor = :midpoint
    )

    @test size(grid) == (2, 10, 10, 10)
    @test length(grid) == 2000

    choices, log_vol = first(grid)
    @test choices == choicemap(
        (:degree, 1),
        ((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
    )
    @test log_vol ≈ log(0.2^3)

    test_choices(n::Int, cs) =
        cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n)

    @test all(test_choices(2, choices) for (choices, _) in grid)
    @test all(log_vol ≈ log(0.2^3) for (_, log_vol) in grid)

    # run enumerative inference over grid
    traces, log_norm_weights, lml_est =
        enumerative_inference(poly_model, (2, xs), observations, grid)

    @test size(traces) == (2, 10, 10, 10)
    @test length(traces) == 2000
    @test all(test_choices(2, tr) for tr in traces)

    # test that log-weights are as expected
    log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
    lml_expected = logsumexp(log_joint_weights)
    @test lml_est ≈ lml_expected
    @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))

    # test that polynomial is most likely quadratic
    degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
    @test argmax(vec(degree_probs)) == 2

    # test that MAP trace recovers the original coefficients
    # (grid midpoints are computed in floating point, so compare with ≈)
    map_trace_idx = argmax(log_norm_weights)
    map_trace = traces[map_trace_idx]
    @test map_trace[:degree] == 2
    @test map_trace[(:coeff, 0)] ≈ 0.5
    @test map_trace[(:coeff, 1)] ≈ 0.1
    @test map_trace[(:coeff, 2)] ≈ -0.5

    # 2D mixture of normals
    @gen function mixture_model()
        sign ~ bernoulli(0.5)
        mu = sign ? fill(0.5, 2) : fill(-0.5, 2)
        z ~ broadcasted_normal(mu, ones(2))
    end

    # test construction of grid with 2D random variable
    grid = choice_vol_grid(
        (:sign, [false, true]),
        (:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)),
        anchor = :left
    )

    @test size(grid) == (2, 40, 40)
    @test length(grid) == 3200

    choices, log_vol = first(grid)
    @test choices == choicemap((:sign, false), (:z, [-2.0, -2.0]))
    @test log_vol ≈ log(0.1^2)

    @test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid)
    @test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid)

    # run enumerative inference over grid
    traces, log_norm_weights, lml_est =
        enumerative_inference(mixture_model, (), choicemap(), grid)

    @test size(traces) == (2, 40, 40)
    @test length(traces) == 3200
    @test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces)

    # test that log-weights are as expected
    function expected_logpdf(tr)
        x, y = tr[:z]
        mu = tr[:sign] ? 0.5 : -0.5
        return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0)
    end

    log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces]
    lml_expected = logsumexp(log_joint_weights)
    @test lml_est ≈ lml_expected
    @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))

    # test that maximal log-weights are at modes
    max_log_weight = maximum(log_norm_weights)
    max_idxs = findall(log_norm_weights .== max_log_weight)

    max_trace_1 = traces[max_idxs[1]]
    @test max_trace_1[:sign] == false
    @test max_trace_1[:z] ≈ [-0.5, -0.5]

    max_trace_2 = traces[max_idxs[2]]
    @test max_trace_2[:sign] == true
    @test max_trace_2[:z] ≈ [0.5, 0.5]

end

0 commit comments

Comments
 (0)