Skip to content

Commit 2b04312

Browse files
authored
use Bijectors.jl for variable transformations, rather than Transform… (#11)
* use Bijectors.jl for variable transformations, rather than TransformVariables. This allows more flexibility for specifying and combining transformations; combinations are applicable (also the inverse) to vectors (rather than tuples). Better support for AD on GPU.
1 parent f7cd1a3 commit 2b04312

File tree

9 files changed

+64
-78
lines changed

9 files changed

+64
-78
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Thomas Wutzler <[email protected]> and contributors"]
44
version = "1.0.0-DEV"
55

66
[deps]
7+
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
78
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
89
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,7 +15,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1516
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1617
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
17-
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
1818
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1919

2020
[weakdeps]
@@ -28,6 +28,7 @@ HybridVariationalInferenceLuxExt = "Lux"
2828
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
2929

3030
[compat]
31+
Bijectors = "0.15.4"
3132
BlockDiagonals = "0.1.42"
3233
CUDA = "5.5.2"
3334
ChainRulesCore = "1.25"
@@ -41,7 +42,6 @@ Random = "1.10.0"
4142
SimpleChains = "0.4"
4243
StatsBase = "0.34.4"
4344
StatsFuns = "1.3.2"
44-
TransformVariables = "0.8.10"
4545
Zygote = "0.6.73"
4646
julia = "1.10"
4747

dev/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
23
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
45
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
@@ -15,7 +16,6 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
1516
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1617
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1718
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
18-
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
1919
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2020
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2121
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

src/HybridVariationalInference.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,10 @@ using GPUArraysCore
88
using LinearAlgebra
99
using CUDA
1010
using ChainRulesCore
11-
using TransformVariables
11+
using Bijectors
1212
using Zygote # Zygote.@ignore CUDA.randn
1313
using BlockDiagonals
1414

15-
export inverse_ca
16-
include("util_transformvariables.jl")
17-
1815
export ComponentArrayInterpreter, flatten1, get_concrete
1916
include("ComponentArrayInterpreter.jl")
2017

src/elbo.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ expected value of the likelihood of observations.
1313
interpreted by interpreters.μP_ϕg_unc and interpreters.PMs
1414
- y_ob: matrix of observations (n_obs x n_site_batch)
1515
- x: matrix of covariates (n_cov x n_site_batch)
16-
- transPMs: Transformations with components P, Ms, similar to interpreters
16+
- transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params
1717
- n_MC: number of MonteCarlo samples from the distribution of parameters to simulate
1818
using the mechanistic model f.
1919
- logσ2y: observation uncertainty (log of the variance)
@@ -27,7 +27,7 @@ function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractM
2727
ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
2828
#ζi = first(eachcol(ζs_cpu))
2929
nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
30-
y_pred_i, logjac = predict_y(ζi, f, transPMs)
30+
y_pred_i, logjac = predict_y(ζi, f, transPMs, interpreters.PMs)
3131
nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
3232
nLy1 - logjac
3333
end) / n_MC
@@ -51,10 +51,12 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret
5151
n_site = size(xM, 2)
5252
intm_PMs_gen = get_ca_int_PMs(n_site)
5353
trans_PMs_gen = get_transPMs(n_site)
54+
interpreters_gen = (; interpreters..., PMs = intm_PMs_gen)
5455
ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM),
55-
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred)
56+
interpreters_gen; n_MC = n_sample_pred)
5657
ζs_cpu = gpu_data_handler(ζs) #
57-
y_pred = stack(map(ζ -> first(predict_y(ζ, f, trans_PMs_gen)), eachcol(ζs_cpu)));
58+
y_pred = stack(map(ζ -> first(predict_y(
59+
ζ, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
5860
y_pred
5961
end
6062

@@ -77,9 +79,9 @@ function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix,
7779
#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
7880
ζ = stack(map(eachcol(ζ_resid)) do r
7981
rc = interpreters.PMs(r)
80-
ζP = μ_ζP .+ rc.θP
82+
ζP = μ_ζP .+ rc.P
8183
μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g
82-
ζMs = μ_ζMs .+ rc.θMs
84+
ζMs = μ_ζMs .+ rc.Ms
8385
vcat(ζP, vec(ζMs))
8486
end)
8587
ζ, σ
@@ -166,9 +168,11 @@ Steps:
166168
- transform the parameters to original constrained space
167169
- Applies the mechanistic model for each site
168170
"""
169-
function predict_y(ζi, f, transPMs)
170-
θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating
171-
θc = CA.ComponentVector(θtup)
171+
function predict_y(ζi, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter)
172+
# θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating
173+
# θc = CA.ComponentVector(θtup)
174+
θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating
175+
θc = int_PMs(θ)
172176
# TODO provide xP
173177
xP = fill((), size(θc.Ms,2))
174178
y_pred_global, y_pred = f(θc.P, θc.Ms, xP) # TODO parallelize on CPU

src/init_hybrid_params.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@ Returns a NamedTuple of
1414
- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
1515
- `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator`
1616
- `n_batch`: the number of sites to predicted in each mini-batch
17-
- `transP`, `transM`: the Transformations for the global and site-dependent parameters
17+
- `transP`, `transM`: the `Bijectors.Transform`s for the global and site-dependent
18+
parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`.
19+
It is the transformation going from unconstrained to constrained space: θ = Tinv(ζ),
20+
because this direction is used much more often.
1821
"""
19-
function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ)
22+
function init_hybrid_params(θP, θM, ϕg, n_batch;
23+
transP=elementwise(identity), transM=elementwise(identity))
2024
n_θP = length(θP)
2125
n_θM = length(θM)
2226
n_ϕg = length(ϕg)
27+
# check that the inverse transformations can be applied to the parameters (lengths must match)
28+
_ = Bijectors.inverse(transP)(θP)
29+
_ = Bijectors.inverse(transM)(θM)
2330
# zero correlation matrices
2431
ρsP = zeros(sum(1:(n_θP - 1)))
2532
ρsM = zeros(sum(1:(n_θM - 1)))
@@ -28,39 +35,35 @@ function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ)
2835
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
2936
ρsP,
3037
ρsM)
31-
ϕt = CA.ComponentVector(;
32-
μP = θP,
38+
ϕ = CA.ComponentVector(;
39+
μP = inverse(transP)(θP),
3340
ϕg = ϕg,
3441
unc = ϕunc0);
3542
#
3643
get_transPMs = let transP=transP, transM=transM, n_θP=n_θP, n_θM=n_θM
3744
function get_transPMs_inner(n_site)
38-
transPMs = as(
39-
(P = as(Array, transP, n_θP),
40-
Ms = as(Array, transM, n_θM, n_site)))
45+
transMs = ntuple(i -> transM, n_site)
46+
ranges = vcat([1:n_θP], [(n_θP + i0*n_θM) .+ (1:n_θM) for i0 in 0:(n_site-1)])
47+
transPMs = Stacked((transP, transMs...), ranges)
48+
transPMs
4149
end
4250
end
4351
transPMs_batch = get_transPMs(n_batch)
44-
trans_gu = as(
45-
(μP = as(Array, asℝ₊, n_θP),
46-
ϕg = as(Array, n_ϕg),
47-
unc = as(Array, length(ϕunc0))))
48-
ϕ = inverse_ca(trans_gu, ϕt)
49-
# trans_g = as(
50-
# (μP = as(Array, asℝ₊, n_θP),
51-
# ϕg = as(Array, n_ϕg)))
52-
#
52+
# ranges = (P = 1:n_θP, ϕg = n_θP .+ (1:n_ϕg), unc = (n_θP + n_ϕg) .+ (1:length(ϕunc0)))
53+
# inv_trans_gu = Stacked(
54+
# (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges))
55+
# ϕ = inv_trans_gu(CA.getdata(ϕt))
5356
get_ca_int_PMs = let
5457
function get_ca_int_PMs_inner(n_site)
55-
ComponentArrayInterpreter(CA.ComponentVector(; θP,
56-
θMs = CA.ComponentMatrix(
58+
ComponentArrayInterpreter(CA.ComponentVector(; P=θP,
59+
Ms = CA.ComponentMatrix(
5760
zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site))))
5861
end
5962

6063
end
6164
interpreters = map(get_concrete,
6265
(;
63-
μP_ϕg_unc = ComponentArrayInterpreter(ϕt),
66+
μP_ϕg_unc = ComponentArrayInterpreter(ϕ),
6467
PMs = get_ca_int_PMs(n_batch),
6568
unc = ComponentArrayInterpreter(ϕunc0)
6669
))

src/util_transformvariables.jl

Lines changed: 0 additions & 21 deletions
This file was deleted.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
34
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
45
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
56
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -16,6 +17,5 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
19-
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
2020
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2121
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

test/test_elbo.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using StableRNGs
77
using Random
88
using SimpleChains
99
using ComponentArrays: ComponentArrays as CA
10-
using TransformVariables
10+
#using TransformVariables
11+
using Bijectors
1112
using Zygote
1213
using CUDA
1314
using GPUArraysCore: GPUArraysCore
@@ -30,8 +31,11 @@ f = gen_hybridcase_PBmodel(case; scenario)
3031

3132
logσ2y = 2 .* log.(σ_o)
3233
n_MC = 3
34+
transP = elementwise(exp)
35+
transM = Stacked(elementwise(identity), elementwise(exp))
36+
#transM = Stacked(elementwise(identity), elementwise(exp), elementwise(exp)) # test mismatch
3337
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
34-
θP_true, θMs_true[:, 1], ϕg0, n_batch; transP = asℝ₊, transM = asℝ₊);
38+
θP_true, θMs_true[:, 1], ϕg0, n_batch; transP, transM);
3539
ϕ_ini = ϕ
3640

3741
() -> begin
@@ -163,7 +167,11 @@ if CUDA.functional()
163167
end
164168

165169
@testset "predict_gf cpu" begin
166-
n_sample_pred = 200
170+
n_sample_pred = n_site = 200
171+
intm_PMs_gen = get_ca_int_PMs(n_site)
172+
trans_PMs_gen = get_transPMs(n_site)
173+
@test length(intm_PMs_gen) == 402
174+
@test trans_PMs_gen.length_in == 402
167175
y_pred = predict_gf(rng, g, f, ϕ_ini, xM, map(get_concrete, interpreters);
168176
get_transPMs, get_ca_int_PMs, n_sample_pred)
169177
@test y_pred isa Array

test/test_sample_zeta.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using GPUArraysCore: GPUArraysCore
1010
using Random
1111
#using SimpleChains
1212
using ComponentArrays: ComponentArrays as CA
13-
using TransformVariables
13+
using Bijectors
1414

1515
#CUDA.device!(4)
1616
rng = StableRNG(111)
@@ -23,37 +23,32 @@ scenario = (:default,)
2323

2424
@testset "test_sample_zeta" begin
2525
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
26-
) = gen_hybridcase_synthetic(case, rng; scenario)
26+
) = gen_hybridcase_synthetic(case, rng; scenario)
2727

2828
# n_site = 2
2929
# n_θP, n_θM = length(θ_true.θP), length(θ_true.θM)
3030
# σ_θM = θ_true.θM .* 0.1 # 10% around expected
3131
# θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM
3232

3333
# set to 0.02 rather than zero for debugging non-zero correlations
34-
ρsP = zeros(sum(1:(n_θP - 1))) .+ 0.02
35-
ρsM = zeros(sum(1:(n_θM - 1))) .+ 0.02
34+
ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02
35+
ρsM = zeros(sum(1:(n_θM-1))) .+ 0.02
3636

3737
ϕunc = CA.ComponentVector(;
38-
logσ2_logP = fill(-10.0, n_θP),
39-
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
38+
logσ2_logP=fill(-10.0, n_θP),
39+
coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
4040
ρsP,
4141
ρsM)
4242

43-
transPMs = as(
44-
(P = as(Array, asℝ₊, n_θP),
45-
Ms = as(Array, asℝ₊, n_θM, n_site)))
4643
θ_true = θ = CA.ComponentVector(;
47-
P = θP_true,
48-
Ms = θMs_true)
49-
transPMs = as((
50-
P = as(Array, asℝ₊, n_θP),
51-
Ms = as(Array, asℝ₊, n_θM, n_site)))
52-
ζ_true = inverse_ca(transPMs, θ_true)
53-
ϕ_true = vcat(ζ_true, CA.ComponentVector(unc = ϕunc))
54-
ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc = ϕunc))
44+
P=θP_true,
45+
Ms=θMs_true)
46+
transPMs = elementwise(exp) # all parameters on LogNormal scale
47+
ζ_true = inverse(transPMs)(θ_true)
48+
ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc))
49+
ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc=ϕunc))
5550

56-
interpreters = (; pmu = ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs)
51+
interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs)
5752

5853
n_MC = 3
5954
@testset "sample_ζ_norm0 cpu" begin

0 commit comments

Comments
 (0)