Skip to content

Commit 1c46274

Browse files
committed
Add enumerative inference function and test cases.
1 parent a5fc8e3 commit 1c46274

File tree

3 files changed

+212
-0
lines changed

3 files changed

+212
-0
lines changed

src/inference/enumerative.jl

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
(traces, log_norm_weights, lml_est) = enumerative_inference(
3+
model::GenerativeFunction, model_args::Tuple,
4+
observations::ChoiceMap, choice_vol_iter
5+
)
6+
7+
Run enumerative inference over a `model`, given `observations` and an iterator over
8+
choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
9+
to be iterated over. Return an array of traces and associated log-weights with the
10+
same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
11+
corresponds to the log probability of the volume of sample space that the trace
12+
represents. Also return an estimate of the log marginal likelihood of the
13+
observations (`lml_est`).
14+
15+
All addresses in the `observations` choice map must be sampled by the model when
16+
given the model arguments. The same constraint applies to choice maps enumerated
17+
over by `choice_vol_iter`, which must also avoid sharing addresses with the
18+
`observations`. An iterator over a grid of choice maps and volumes can be constructed
19+
with [`choice_vol_grid`](@ref).
20+
"""
21+
function enumerative_inference(
    model::GenerativeFunction{T,U}, model_args::Tuple,
    observations::ChoiceMap, choice_vol_iter::I
) where {T,U,I}
    # Work out the shape of the output arrays from the iterator's size trait.
    iter_size = Base.IteratorSize(I)
    if iter_size isa Base.HasShape
        dims = size(choice_vol_iter)
    elseif iter_size isa Base.HasLength
        dims = (length(choice_vol_iter),)
    else
        # Any other iterator is materialized first so that the outputs can be
        # allocated up front.
        choice_vol_iter = collect(choice_vol_iter)
        dims = (length(choice_vol_iter),)
    end
    traces = Array{U}(undef, dims)
    log_weights = Array{Float64}(undef, dims)

    # Constrain the model to the observations plus each enumerated choice map,
    # and add the log-volume of the region represented by that choice map.
    for (idx, (choices, vol)) in enumerate(choice_vol_iter)
        trace, weight = generate(model, model_args, merge(observations, choices))
        traces[idx] = trace
        log_weights[idx] = weight + log(vol)
    end

    # The log of the summed weights is returned as the log marginal likelihood
    # estimate and is also used to normalize the per-trace weights.
    lml_est = logsumexp(log_weights)
    return (traces, log_weights .- lml_est, lml_est)
end
45+
46+
"""
47+
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
48+
49+
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
50+
over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration.
51+
52+
Each `addr` is an address of a random choice, and `vals` are the corresponding
53+
values or intervals to enumerate over. The (optional) `support` denotes whether
54+
each random choice is `:discrete` (default) or `:continuous`. This controls how
55+
the grid is constructed:
56+
- `support = :discrete`: The grid iterates over each value in `vals`.
57+
- `support = :continuous` and `dims == Val(1)`: The grid iterates over the
58+
anchors of 1D intervals whose endpoints are given by `vals`.
59+
- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
60+
over the anchors of multi-dimensional regions defined by `vals`, which is a tuple
61+
of interval endpoints for each dimension.
62+
Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
63+
The `anchor` keyword argument controls which point in each interval is used as
64+
the anchor (`:left`, `:right`, or `:midpoint`).
65+
66+
The volume `vol` associated with each set of `choices` in the grid is given by
67+
the product of the volumes of each continuous region used to construct those
68+
choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`.
69+
"""
70+
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
    # One iterable of `(addr, value)` pairs per grid specification.
    value_axes = map(grid_specs) do spec
        expand_grid_spec_to_values(spec...; anchor=anchor)
    end
    # One iterable of volumes per grid specification.
    volume_axes = map(spec -> expand_grid_spec_to_volumes(spec...), grid_specs)

    # Cartesian products over all specifications, walked in lockstep so that
    # each combination of choices is paired with its per-axis volumes.
    value_grid = Iterators.product(value_axes...)
    volume_grid = Iterators.product(volume_axes...)
    return (
        (choicemap(vals...), prod(vols))
        for (vals, vols) in zip(value_grid, volume_grid)
    )
end
81+
82+
# Select the anchor point of each interval delimited by consecutive `endpoints`.
# Throws an `ArgumentError` if `anchor` is not `:left`, `:right`, or `:midpoint`.
function grid_anchor_points(endpoints, anchor::Symbol)
    if anchor == :left
        return @view(endpoints[begin:end-1])
    elseif anchor == :right
        return @view(endpoints[begin+1:end])
    elseif anchor == :midpoint
        return @view(endpoints[begin:end-1]) .+ (diff(endpoints) ./ 2)
    else
        throw(ArgumentError(
            "anchor must be :left, :right, or :midpoint, got $(repr(anchor))"
        ))
    end
end

"""
    expand_grid_spec_to_values(addr, vals, [support, dims]; anchor=:midpoint)

Expand a single grid specification into an iterator over `(addr, value)` pairs.

- `support == :discrete` (default): one pair per element of `vals`; `anchor` is
  not used.
- `support == :continuous` with `dims == Val(1)`: `vals` holds interval
  endpoints, and one pair is produced per interval, using the point selected by
  `anchor` (`:left`, `:right`, or `:midpoint`).
- `support == :continuous` with `dims == Val(N)`, `N > 1`: `vals` holds `N`
  collections of endpoints, and each value is a tuple of per-dimension anchors
  taken from their Cartesian product.

Throws an `ArgumentError` for an unrecognized `anchor` on a continuous
specification, a `DimensionMismatch` if `length(vals) != N` in the
multi-dimensional case, and an error for any other `support`.
"""
function expand_grid_spec_to_values(
    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
    anchor::Symbol = :midpoint
) where {N}
    if support == :discrete
        return ((addr, v) for v in vals)
    elseif support == :continuous && N == 1
        anchors = grid_anchor_points(vals, anchor)
        return ((addr, v) for v in anchors)
    elseif support == :continuous && N > 1
        # Explicit check rather than `@assert`, which may be compiled out.
        length(vals) == N || throw(DimensionMismatch(
            "Dimension mismatch between `vals` and `dims`"
        ))
        anchors = map(vs -> grid_anchor_points(vs, anchor), vals)
        return ((addr, v) for v in Iterators.product(anchors...))
    else
        error("Support must be :discrete or :continuous")
    end
end
114+
115+
"""
    expand_grid_spec_to_volumes(addr, vals, [support, dims])

Expand a single grid specification into the volumes of its grid cells.

- `support == :discrete` (default): a vector of ones, one per element of `vals`.
- `support == :continuous` with `dims == Val(1)`: the differences between
  consecutive endpoints in `vals`.
- `support == :continuous` with `dims == Val(N)`, `N > 1`: an iterator over the
  products of per-dimension endpoint differences, taken over their Cartesian
  product.

`addr` is accepted so that a grid specification can be splatted directly into
this function; it is not used. Throws a `DimensionMismatch` if
`length(vals) != N` in the multi-dimensional case, and an error for any other
`support`.
"""
function expand_grid_spec_to_volumes(
    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
) where {N}
    if support == :discrete
        return ones(length(vals))
    elseif support == :continuous && N == 1
        return diff(vals)
    elseif support == :continuous && N > 1
        # Explicit check rather than `@assert`, which may be compiled out.
        length(vals) == N || throw(DimensionMismatch(
            "Dimension mismatch between `vals` and `dims`"
        ))
        diffs = Iterators.product((diff(vs) for vs in vals)...)
        return (prod(ds) for ds in diffs)
    else
        error("Support must be :discrete or :continuous")
    end
end
130+
131+
export enumerative_inference, choice_vol_grid

src/inference/inference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include("hmc.jl")
2121
include("mala.jl")
2222
include("elliptical_slice.jl")
2323

24+
include("enumerative.jl")
2425
include("importance.jl")
2526
include("particle_filter.jl")
2627
include("map_optimize.jl")

test/inference/enumerative.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
@testset "enumerative inference" begin

    # polynomial regression model
    @gen function poly_model(n::Int, xs)
        degree ~ uniform_discrete(1, n)
        coeffs = zeros(n+1)
        for d in 0:n
            coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
        end
        ys = zeros(length(xs))
        for (i, x) in enumerate(xs)
            x_powers = x .^ (0:n)
            y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
            ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
        end
        return ys
    end

    # synthetic dataset
    coeffs = [0.5, 0.1, -0.5]
    xs = collect(0.5:0.5:3.0)
    ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]

    observations = choicemap()
    for (i, y) in enumerate(ys)
        observations[(:y, i)] = y
    end

    # test construction of choicemap-volume grid
    grid = choice_vol_grid(
        (:degree, 1:2),
        ((:coeff, 0), -1:0.2:1, :continuous),
        ((:coeff, 1), -1:0.2:1, :continuous),
        ((:coeff, 2), -1:0.2:1, :continuous),
        anchor = :midpoint
    )

    @test size(grid) == (2, 10, 10, 10)
    @test length(grid) == 2000

    choices, vol = first(grid)
    @test choices == choicemap(
        (:degree, 1),
        ((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
    )
    @test vol ≈ 0.2 * 0.2 * 0.2

    # every coefficient (including the constant term at index 0) must lie
    # within the enumerated range
    test_choices(n::Int, cs) =
        cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 0:n)

    @test all(test_choices(2, choices) for (choices, vol) in grid)
    @test all(vol ≈ (0.2 * 0.2 * 0.2) for (choices, vol) in grid)

    # run enumerative inference over grid
    traces, log_norm_weights, lml_est =
        enumerative_inference(poly_model, (2, xs), observations, grid)

    @test size(traces) == (2, 10, 10, 10)
    @test length(traces) == 2000
    @test all(test_choices(2, tr) for tr in traces)

    # test that log-weights are as expected
    log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
    lml_expected = logsumexp(log_joint_weights)
    @test lml_est ≈ lml_expected
    @test all(
        (jw - lml_expected) ≈ w
        for (jw, w) in zip(log_joint_weights, log_norm_weights)
    )

    # test that polynomial is most likely quadratic
    degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
    @test argmax(vec(degree_probs)) == 2

    # test that MAP trace recovers the original coefficients
    # (grid midpoints are computed in floating point, so compare approximately)
    map_trace_idx = argmax(log_norm_weights)
    map_trace = traces[map_trace_idx]
    @test map_trace[:degree] == 2
    @test map_trace[(:coeff, 0)] ≈ 0.5
    @test map_trace[(:coeff, 1)] ≈ 0.1
    @test map_trace[(:coeff, 2)] ≈ -0.5

end

0 commit comments

Comments
 (0)