Skip to content

Commit 1146ca0

Browse files
committed
Add multi-variate enumeration test case.
1 parent 7104004 commit 1146ca0

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

src/inference/enumerative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function expand_grid_spec_to_values(
107107
end
108108
return vs
109109
end
110-
return ((addr, v) for v in Iterators.product(vals...))
110+
return ((addr, collect(v)) for v in Iterators.product(vals...))
111111
else
112112
error("Support must be :discrete or :continuous")
113113
end

test/inference/enumerative.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,60 @@
7777
@test map_trace[(:coeff, 1)] == 0.1
7878
@test map_trace[(:coeff, 2)] == -0.5
7979

80+
# 2D mixture of normals
81+
@gen function mixture_model()
82+
sign ~ bernoulli(0.5)
83+
mu = sign ? fill(0.5, 2) : fill(-0.5, 2)
84+
z ~ broadcasted_normal(mu, ones(2))
85+
end
86+
87+
# test construction of grid with 2D random variable
88+
grid = choice_vol_grid(
89+
(:sign, [false, true]),
90+
(:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)),
91+
anchor = :left
92+
)
93+
94+
@test size(grid) == (2, 40, 40)
95+
@test length(grid) == 3200
96+
97+
choices, log_vol = first(grid)
98+
@test choices == choicemap((:sign, false), (:z, [-2.0, -2.0]))
99+
@test log_vol log(0.1^2)
100+
101+
@test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid)
102+
@test all(log_vol log(0.1^2) for (_, log_vol) in grid)
103+
104+
# run enumerative inference over grid
105+
traces, log_norm_weights, lml_est =
106+
enumerative_inference(mixture_model, (), choicemap(), grid)
107+
108+
@test size(traces) == (2, 40, 40)
109+
@test length(traces) == 3200
110+
@test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces)
111+
112+
# test that log-weights are as expected
113+
function expected_logpdf(tr)
114+
x, y = tr[:z]
115+
mu = tr[:sign] ? 0.5 : -0.5
116+
return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0)
117+
end
118+
119+
log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces]
120+
lml_expected = logsumexp(log_joint_weights)
121+
@test lml_est lml_expected
122+
@test all((jw - lml_expected) w for (jw, w) in zip(log_joint_weights, log_norm_weights))
123+
124+
# test that maximal log-weights are at modes
125+
max_log_weight = maximum(log_norm_weights)
126+
max_idxs = findall(log_norm_weights .== max_log_weight)
127+
128+
max_trace_1 = traces[max_idxs[1]]
129+
@test max_trace_1[:sign] == false
130+
@test max_trace_1[:z] == [-0.5, -0.5]
131+
132+
max_trace_2 = traces[max_idxs[2]]
133+
@test max_trace_2[:sign] == true
134+
@test max_trace_2[:z] == [0.5, 0.5]
135+
80136
end

0 commit comments

Comments
 (0)