Skip to content

Commit c5404a1

Browse files
committed
bump Bijectors.jl compat bounds and replace forward with rand_and_logjac
1 parent bb7e85c commit c5404a1

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedVI"
22
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
3-
version = "0.1.6"
3+
version = "0.2.0"
44

55
[deps]
66
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
@@ -17,7 +17,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1717
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1818

1919
[compat]
20-
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10"
20+
Bijectors = "0.11, 0.12"
2121
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
2222
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
2323
DocStringExtensions = "0.8, 0.9"

src/AdvancedVI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ end
1919
const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
2020

2121
include("ad.jl")
22+
include("utils.jl")
2223

2324
using Requires
2425
function __init__()

src/advi.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ function (elbo::ELBO)(
8181
# = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃))
8282
# = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃))
8383

84-
# But our `forward(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
85-
_, z, logjac, _ = forward(rng, q)
84+
# But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
85+
z, logjac = rand_and_logjac(rng, q)
8686
res = (logπ(z) + logjac) / num_samples
8787

8888
if q isa TransformedDistribution
@@ -92,7 +92,7 @@ function (elbo::ELBO)(
9292
end
9393

9494
for i = 2:num_samples
95-
_, z, logjac, _ = forward(rng, q)
95+
z, logjac = rand_and_logjac(rng, q)
9696
res += (logπ(z) + logjac) / num_samples
9797
end
9898

src/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Distributions
2+
3+
using Random: Random
4+
using Bijectors: Bijectors
5+
6+
7+
function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution)
8+
x = rand(rng, dist)
9+
return x, zero(eltype(x))
10+
end
11+
12+
function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution)
13+
x = rand(rng, dist.dist)
14+
y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
15+
return y, logjac
16+
end

0 commit comments

Comments
 (0)