|
| 1 | +@testset "enumerative inference" begin |
| 2 | + |
| 3 | + # polynomial regression model |
| 4 | + @gen function poly_model(n::Int, xs) |
| 5 | + degree ~ uniform_discrete(1, n) |
| 6 | + coeffs = zeros(n+1) |
| 7 | + for d in 0:n |
| 8 | + coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1) |
| 9 | + end |
| 10 | + ys = zeros(length(xs)) |
| 11 | + for (i, x) in enumerate(xs) |
| 12 | + x_powers = x .^ (0:n) |
| 13 | + y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree) |
| 14 | + ys[i] = {(:y, i)} ~ normal(y_mean, 0.1) |
| 15 | + end |
| 16 | + return ys |
| 17 | + end |
| 18 | + |
| 19 | + # synthetic dataset |
| 20 | + coeffs = [0.5, 0.1, -0.5] |
| 21 | + xs = collect(0.5:0.5:3.0) |
| 22 | + ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs] |
| 23 | + |
| 24 | + observations = choicemap() |
| 25 | + for (i, y) in enumerate(ys) |
| 26 | + observations[(:y, i)] = y |
| 27 | + end |
| 28 | + |
| 29 | + # test construction of choicemap-volume grid |
| 30 | + grid = choice_vol_grid( |
| 31 | + (:degree, 1:2), |
| 32 | + ((:coeff, 0), -1:0.2:1, :continuous), |
| 33 | + ((:coeff, 1), -1:0.2:1, :continuous), |
| 34 | + ((:coeff, 2), -1:0.2:1, :continuous), |
| 35 | + anchor = :midpoint |
| 36 | + ) |
| 37 | + |
| 38 | + @test size(grid) == (2, 10, 10, 10) |
| 39 | + @test length(grid) == 2000 |
| 40 | + |
| 41 | + choices, log_vol = first(grid) |
| 42 | + @test choices == choicemap( |
| 43 | + (:degree, 1), |
| 44 | + ((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9), |
| 45 | + ) |
| 46 | + @test log_vol ≈ log(0.2^3) |
| 47 | + |
| 48 | + test_choices(n::Int, cs) = |
| 49 | + cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n) |
| 50 | + |
| 51 | + @test all(test_choices(2, choices) for (choices, _) in grid) |
| 52 | + @test all(log_vol ≈ log(0.2^3) for (_, log_vol) in grid) |
| 53 | + |
| 54 | + # run enumerative inference over grid |
| 55 | + traces, log_norm_weights, lml_est = |
| 56 | + enumerative_inference(poly_model, (2, xs), observations, grid) |
| 57 | + |
| 58 | + @test size(traces) == (2, 10, 10, 10) |
| 59 | + @test length(traces) == 2000 |
| 60 | + @test all(test_choices(2, tr) for tr in traces) |
| 61 | + |
| 62 | + # test that log-weights are as expected |
| 63 | + log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces] |
| 64 | + lml_expected = logsumexp(log_joint_weights) |
| 65 | + @test lml_est ≈ lml_expected |
| 66 | + @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)) |
| 67 | + |
| 68 | + # test that polynomial is most likely quadratic |
| 69 | + degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4)) |
| 70 | + @test argmax(vec(degree_probs)) == 2 |
| 71 | + |
| 72 | + # test that MAP trace recovers the original coefficients |
| 73 | + map_trace_idx = argmax(log_norm_weights) |
| 74 | + map_trace = traces[map_trace_idx] |
| 75 | + @test map_trace[:degree] == 2 |
| 76 | + @test map_trace[(:coeff, 0)] == 0.5 |
| 77 | + @test map_trace[(:coeff, 1)] == 0.1 |
| 78 | + @test map_trace[(:coeff, 2)] == -0.5 |
| 79 | + |
| 80 | + # 2D mixture of normals |
| 81 | + @gen function mixture_model() |
| 82 | + sign ~ bernoulli(0.5) |
| 83 | + mu = sign ? fill(0.5, 2) : fill(-0.5, 2) |
| 84 | + z ~ broadcasted_normal(mu, ones(2)) |
| 85 | + end |
| 86 | + |
| 87 | + # test construction of grid with 2D random variable |
| 88 | + grid = choice_vol_grid( |
| 89 | + (:sign, [false, true]), |
| 90 | + (:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)), |
| 91 | + anchor = :left |
| 92 | + ) |
| 93 | + |
| 94 | + @test size(grid) == (2, 40, 40) |
| 95 | + @test length(grid) == 3200 |
| 96 | + |
| 97 | + choices, log_vol = first(grid) |
| 98 | + @test choices == choicemap((:sign, false), (:z, [-2.0, -2.0])) |
| 99 | + @test log_vol ≈ log(0.1^2) |
| 100 | + |
| 101 | + @test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid) |
| 102 | + @test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid) |
| 103 | + |
| 104 | + # run enumerative inference over grid |
| 105 | + traces, log_norm_weights, lml_est = |
| 106 | + enumerative_inference(mixture_model, (), choicemap(), grid) |
| 107 | + |
| 108 | + @test size(traces) == (2, 40, 40) |
| 109 | + @test length(traces) == 3200 |
| 110 | + @test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces) |
| 111 | + |
| 112 | + # test that log-weights are as expected |
| 113 | + function expected_logpdf(tr) |
| 114 | + x, y = tr[:z] |
| 115 | + mu = tr[:sign] ? 0.5 : -0.5 |
| 116 | + return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0) |
| 117 | + end |
| 118 | + |
| 119 | + log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces] |
| 120 | + lml_expected = logsumexp(log_joint_weights) |
| 121 | + @test lml_est ≈ lml_expected |
| 122 | + @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)) |
| 123 | + |
| 124 | + # test that maximal log-weights are at modes |
| 125 | + max_log_weight = maximum(log_norm_weights) |
| 126 | + max_idxs = findall(log_norm_weights .== max_log_weight) |
| 127 | + |
| 128 | + max_trace_1 = traces[max_idxs[1]] |
| 129 | + @test max_trace_1[:sign] == false |
| 130 | + @test max_trace_1[:z] == [-0.5, -0.5] |
| 131 | + |
| 132 | + max_trace_2 = traces[max_idxs[2]] |
| 133 | + @test max_trace_2[:sign] == true |
| 134 | + @test max_trace_2[:z] == [0.5, 0.5] |
| 135 | + |
| 136 | +end |
0 commit comments