|
2 | 2 | # Test transformed distributions |
3 | 3 | @dist f(x) = exp(normal(x, 0.001)) |
4 | 4 | @test isapprox(1, f(0); atol = 5) |
| 5 | + @test isapprox(logpdf(f, 1., 0.), logpdf(normal, 0., 0., 0.001)) |
| 6 | + |
| 7 | + # Test gradients of transformed distributions |
| 8 | + @dist shifted_normal(mu, sigma) = normal(mu, sigma) + 1. |
| 9 | + @test logpdf(shifted_normal, 1., 0., 1.) == logpdf(normal, 0., 0., 1.) |
| 10 | + @test logpdf_grad(shifted_normal, 0., 0., 1.) == logpdf_grad(normal, -1., 0., 1.) |
| 11 | + |
| 12 | + # Test gradients of transformed distributions with no parameters |
| 13 | + @dist shifted_std_normal() = normal(0., 1.) + 1. |
| 14 | + @test logpdf(shifted_std_normal, 1.) == logpdf(normal, 0., 0., 1.) |
| 15 | + @test logpdf_grad(shifted_std_normal, 0.) == (logpdf_grad(normal, -1., 0., 1.)[1],) |
| 16 | + |
| 17 | + # Test gradients of multivariate distributions |
| 18 | + @dist vec_normal(mu, sigma) = broadcasted_normal(broadcast(+, mu, 1.0), broadcast(*, sigma, 2.0)) |
| 19 | + @test logpdf(vec_normal, zeros(2), zeros(2), ones(2)) == |
| 20 | + logpdf(broadcasted_normal, zeros(2), ones(2), 2 .* ones(2)) |
| 21 | + transformed_grads = logpdf_grad(vec_normal, zeros(2), zeros(2), ones(2)) |
| 22 | + orig_grads = logpdf_grad(broadcasted_normal, zeros(2), ones(2), 2 .* ones(2)) |
| 23 | + @test transformed_grads[1] == orig_grads[1] |
| 24 | + @test transformed_grads[2] == orig_grads[2] |
| 25 | + @test transformed_grads[3] == 2.0 * orig_grads[3] |
| 26 | + |
| 27 | + # Test gradients of multivariate distributions with multi-dimensional arguments |
| 28 | + @dist transformed_mvnormal(mu, sigma) = mvnormal(broadcast(+, mu, 1.0), broadcast(*, sigma, 2.0)) |
| 29 | + @test logpdf(transformed_mvnormal, zeros(2), zeros(2), ones(2, 2)) == |
| 30 | + logpdf(mvnormal, zeros(2), ones(2), 2 .* ones(2, 2)) |
| 31 | + transformed_grads = logpdf_grad(transformed_mvnormal, zeros(2), zeros(2), ones(2, 2)) |
| 32 | + orig_grads = logpdf_grad(mvnormal, zeros(2), ones(2), 2 .* ones(2, 2)) |
| 33 | + @test transformed_grads[1] == orig_grads[1] |
| 34 | + @test transformed_grads[2] == orig_grads[2] |
| 35 | + @test transformed_grads[3] == 2.0 * orig_grads[3] |
5 | 36 |
|
6 | 37 | # Test relabeled distributions with labels provided as an Array |
7 | 38 | @dist labeled_cat(labels, probs) = labels[categorical(probs)] |
8 | 39 | @test labeled_cat([:a, :b], [0., 1.]) == :b |
9 | 40 | @test isapprox(logpdf(labeled_cat, :b, [:a, :b], [0.5, 0.5]), log(0.5)) |
| 41 | + @test logpdf_grad(labeled_cat, :b, [:a, :b], [0.5, 0.5]) == (nothing, logpdf_grad(categorical, 2, [0.5, 0.5])...) |
10 | 42 | @test logpdf(labeled_cat, :c, [:a, :b], [0.5, 0.5]) == -Inf |
11 | 43 |
|
12 | 44 | # Test relabeled distributions with labels provided in a Dict |
13 | 45 | dict = Dict(1 => :a, 2 => :b) |
14 | 46 | @dist dict_cat(probs) = dict[categorical(probs)] |
15 | 47 | @test dict_cat([0., 1.]) == :b |
16 | 48 | @test isapprox(logpdf(dict_cat, :b, [0.5, 0.5]), log(0.5)) |
| 49 | + @test logpdf_grad(dict_cat, :b, [0.5, 0.5]) == logpdf_grad(categorical, 2, [0.5, 0.5]) |
17 | 50 | @test logpdf(dict_cat, :c, [0.5, 0.5]) == -Inf |
18 | 51 |
|
19 | 52 | # Test relabeled distributions with Enum labels |
|
22 | 55 | @test enum_cat([0., 1.]) == orange |
23 | 56 | @test isapprox(logpdf(enum_cat, orange, [0.5, 0.5]), log(0.5)) |
24 | 57 | @test logpdf(enum_cat, orange, [1.0]) == -Inf |
| 58 | + @test logpdf_grad(enum_cat, orange, [0.5, 0.5]) == logpdf_grad(categorical, 2, [0.5, 0.5]) |
25 | 59 |
|
26 | 60 | # Regression test for https://github.com/probcomp/Gen/issues/253 |
27 | 61 | @dist real_minus_uniform(a, b) = 1 - Gen.uniform(a, b) |
28 | 62 | @test real_minus_uniform(1, 2) < 0 |
29 | 63 | @test logpdf(real_minus_uniform, -0.5, 1, 2) == 0.0 |
| 64 | + @test logpdf_grad(real_minus_uniform, -0.5, 1, 2) == logpdf_grad(uniform, 1.5, 1, 2) |
30 | 65 | end |
31 | 66 |
|
32 | 67 | # User-defined type for testing purposes |
|
40 | 75 | @test symbol_cat([:a, :b], [0., 1.]) == :b |
41 | 76 | @test_throws MethodError symbol_cat(["a", "b"], [0., 1.]) |
42 | 77 | @test logpdf(symbol_cat, :c, [:a, :b], [0.5, 0.5]) == -Inf |
| 78 | + @test logpdf_grad(symbol_cat, :b, [:a, :b], [0.5, 0.5]) == (nothing, logpdf_grad(categorical, 2, [0.5, 0.5])...) |
43 | 79 | @test_throws MethodError logpdf(symbol_cat, "c", [:a, :b], [0.5, 0.5]) |
44 | 80 |
|
45 | 81 | # Test typed parameters |
46 | 82 | @dist int_bounded_uniform(low::Int, high::Int) = uniform(low, high) |
47 | 83 | @test 0.0 <= int_bounded_uniform(0, 1) <= 1 |
48 | 84 | @test_throws MethodError int_bounded_uniform(-0.5, 0.5) |
49 | 85 | @test logpdf(int_bounded_uniform, 0.5, 0, 1) == 0 |
| 86 | + @test logpdf_grad(int_bounded_uniform, 0.5, 0, 1) == logpdf_grad(uniform, 0.5, 0, 1) |
50 | 87 | @test_throws MethodError logpdf(int_bounded_uniform, 0.0, -0.5, 0.5) |
51 | 88 |
|
52 | 89 | # Test relabeled distributions with user-defined types |
|
0 commit comments