|
1 | 1 | import DataStructures: OrderedDict |
| 2 | +import LinearAlgebra: diagm |
2 | 3 |
|
3 | 4 | @testset "bernoulli" begin |
4 | 5 |
|
|
233 | 234 | @test isapprox(actual[3][2, 2], finite_diff_mat_sym(f, args, 3, 2, 2, dx)) |
234 | 235 | end |
235 | 236 |
|
| 237 | +@testset "dirichlet" begin |
| 238 | + x = dirichlet([1., 1., 1., 1.]) |
| 239 | + @test length(x) == 4 |
| 240 | + @test isapprox(sum(x), 1.) |
| 241 | + |
| 242 | + # bounds checking |
| 243 | + @test logpdf(dirichlet, [0., 0], [1., 1.]) == -Inf |
| 244 | + @test logpdf(dirichlet, [1., 1.], [1., 1.]) == -Inf |
| 245 | + @test logpdf(dirichlet, [2., -1], [1., 1.]) == -Inf |
| 246 | + @test logpdf(dirichlet, [.5, .5], [-1., 1.]) == -Inf |
| 247 | + @test logpdf(dirichlet, [.5, .5], [-1., 1.]) == -Inf |
| 248 | + @test logpdf(dirichlet, [0., 1], [1., 1.]) != -Inf |
| 249 | + |
| 250 | + @test isapprox(logpdf(dirichlet, [.01, .99], [2., 2.]), |
| 251 | + Distributions.logpdf(Distributions.Dirichlet([2., 2.]), [.01, .99])) |
| 252 | + @test isapprox(logpdf(dirichlet, [.01, .99], [1., 4.]), |
| 253 | + Distributions.logpdf(Distributions.Dirichlet([1., 4.]), [.01, .99])) |
| 254 | + @test isapprox(logpdf(dirichlet, [.01, .99], [.01, .01]), |
| 255 | + Distributions.logpdf(Distributions.Dirichlet([.01, .01]), [.01, .99])) |
| 256 | + |
| 257 | + # for d > 2 |
| 258 | + @test isapprox(logpdf(dirichlet, [.2, .2, .6], [2., 2., 4.]), |
| 259 | + Distributions.logpdf(Distributions.Dirichlet([2., 2., 4.]), [.2, .2, .6])) |
| 260 | + |
| 261 | + function softmax(x) |
| 262 | + exp.(x) / sum(exp.(x)) |
| 263 | + end |
| 264 | + |
| 265 | + function softmax_grad(x) |
| 266 | + diagm(x) .- (x .* x') |
| 267 | + end |
| 268 | + |
| 269 | + f = (x, alpha) -> logpdf(dirichlet, x, alpha) |
| 270 | + f_normalized = (x, alpha) -> logpdf(dirichlet, softmax(x), alpha) |
| 271 | + |
| 272 | + args = ([0., 0., 0., 0.], [1., 2., 3., 3.]) |
| 273 | + normalized_args = ([.25, .25, .25, .25], [1., 2., 3., 3.]) |
| 274 | + |
| 275 | + actual = logpdf_grad(dirichlet, normalized_args...) |
| 276 | + |
| 277 | + # gradients with respect to x |
| 278 | + actual_x_grad = actual[1]' * softmax_grad(normalized_args[1]) |
| 279 | + |
| 280 | + @test isapprox(actual_x_grad[1], finite_diff_vec(f_normalized, args, 1, 1, dx)) |
| 281 | + @test isapprox(actual_x_grad[2], finite_diff_vec(f_normalized, args, 1, 2, dx)) |
| 282 | + @test isapprox(actual_x_grad[3], finite_diff_vec(f_normalized, args, 1, 3, dx)) |
| 283 | + @test isapprox(actual_x_grad[4], finite_diff_vec(f_normalized, args, 1, 4, dx)) |
| 284 | + |
| 285 | + # gradients with respect to alpha |
| 286 | + @test isapprox(actual[2][1], finite_diff_vec(f, normalized_args, 2, 1, dx)) |
| 287 | + @test isapprox(actual[2][2], finite_diff_vec(f, normalized_args, 2, 2, dx)) |
| 288 | + @test isapprox(actual[2][3], finite_diff_vec(f, normalized_args, 2, 3, dx)) |
| 289 | + @test isapprox(actual[2][4], finite_diff_vec(f, normalized_args, 2, 4, dx)) |
| 290 | +end |
| 291 | + |
236 | 292 | @testset "uniform" begin |
237 | 293 |
|
238 | 294 | # random |
|
0 commit comments