|
77 | 77 | @test map_trace[(:coeff, 1)] == 0.1 |
78 | 78 | @test map_trace[(:coeff, 2)] == -0.5 |
79 | 79 |
|
| 80 | + # 2D mixture of normals |
| 81 | + @gen function mixture_model() |
| 82 | + sign ~ bernoulli(0.5) |
| 83 | + mu = sign ? fill(0.5, 2) : fill(-0.5, 2) |
| 84 | + z ~ broadcasted_normal(mu, ones(2)) |
| 85 | + end |
| 86 | + |
| 87 | + # test construction of grid with 2D random variable |
| 88 | + grid = choice_vol_grid( |
| 89 | + (:sign, [false, true]), |
| 90 | + (:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)), |
| 91 | + anchor = :left |
| 92 | + ) |
| 93 | + |
| 94 | + @test size(grid) == (2, 40, 40) |
| 95 | + @test length(grid) == 3200 |
| 96 | + |
| 97 | + choices, log_vol = first(grid) |
| 98 | + @test choices == choicemap((:sign, false), (:z, [-2.0, -2.0])) |
| 99 | + @test log_vol ≈ log(0.1^2) |
| 100 | + |
| 101 | + @test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid) |
| 102 | + @test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid) |
| 103 | + |
| 104 | + # run enumerative inference over grid |
| 105 | + traces, log_norm_weights, lml_est = |
| 106 | + enumerative_inference(mixture_model, (), choicemap(), grid) |
| 107 | + |
| 108 | + @test size(traces) == (2, 40, 40) |
| 109 | + @test length(traces) == 3200 |
| 110 | + @test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces) |
| 111 | + |
| 112 | + # test that log-weights are as expected |
| 113 | + function expected_logpdf(tr) |
| 114 | + x, y = tr[:z] |
| 115 | + mu = tr[:sign] ? 0.5 : -0.5 |
| 116 | + return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0) |
| 117 | + end |
| 118 | + |
| 119 | + log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces] |
| 120 | + lml_expected = logsumexp(log_joint_weights) |
| 121 | + @test lml_est ≈ lml_expected |
| 122 | + @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)) |
| 123 | + |
| 124 | + # test that maximal log-weights are at modes |
| 125 | + max_log_weight = maximum(log_norm_weights) |
| 126 | + max_idxs = findall(log_norm_weights .== max_log_weight) |
| 127 | + |
| 128 | + max_trace_1 = traces[max_idxs[1]] |
| 129 | + @test max_trace_1[:sign] == false |
| 130 | + @test max_trace_1[:z] == [-0.5, -0.5] |
| 131 | + |
| 132 | + max_trace_2 = traces[max_idxs[2]] |
| 133 | + @test max_trace_2[:sign] == true |
| 134 | + @test max_trace_2[:z] == [0.5, 0.5] |
| 135 | + |
80 | 136 | end |
0 commit comments