Skip to content

Commit 943e60b

Browse files
authored
adapt ELBO to updated derivation (#9)
with entropy over standard Normal but log_det of the multiplication with Cholesky factor.
1 parent 4012634 commit 943e60b

File tree

7 files changed

+42
-31
lines changed

7 files changed

+42
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ CUDA = "5.5.2"
3333
ChainRulesCore = "1.25"
3434
Combinatorics = "1.0.2"
3535
ComponentArrays = "0.15.19"
36-
Flux = "v0.15.2"
36+
Flux = "v0.15.2, 0.16"
3737
GPUArraysCore = "0.1, 0.2"
3838
LinearAlgebra = "1.10.0"
3939
Lux = "1.4.2"

dev/doubleMM.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ end
186186
#res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
187187
end
188188

189-
ϕ = ϕ_true |> Flux.gpu;
189+
ϕ = ϕ_ini |> Flux.gpu;
190190
xM_gpu = xM |> Flux.gpu;
191191
g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario);
192192

@@ -213,22 +213,23 @@ g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario
213213
g_flux = g_luxs
214214
end
215215

216-
function fcost(ϕ)
217-
neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o[:, 1:n_batch],
218-
xM_gpu[:, 1:n_batch], transPMs_batch, map(get_concrete, interpreters);
216+
function fcost(ϕ, xM, y_o)
217+
neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o,
218+
xM, transPMs_batch, map(get_concrete, interpreters);
219219
n_MC = 8, logσ2y = logσ2y)
220220
end
221-
fcost(ϕ)
221+
fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch])
222222
#Zygote.gradient(fcost, ϕ) |> cpu;
223-
gr = Zygote.gradient(fcost, CA.getdata(ϕ));
223+
gr = Zygote.gradient(fcost,
224+
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
224225
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
225226

226227
train_loader = MLUtils.DataLoader((xM_gpu, y_o), batchsize = n_batch)
227228

228229
optf = Optimization.OptimizationFunction(
229230
(ϕ, data) -> begin
230231
xM, y_o = data
231-
fcost(ϕ)
232+
fcost(ϕ, xM, y_o)
232233
# neg_elbo_transnorm_gf(
233234
# rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
234235
# map(get_concrete, interpreters); n_MC = 5, logσ2y)
@@ -259,22 +260,27 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
259260
# hard to estimate for original very small theta's but otherwise good
260261

261262
# test predicting correct obs-uncertainty of predictive posterior
262-
# TODO reuse g_flux rather than g
263263
n_sample_pred = 200
264+
264265
y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
265266
get_transPMs, get_ca_int_PMs, n_sample_pred);
266267
size(y_pred) # n_obs x n_site, n_sample_pred
267268

268-
σ_o_post = dropdims(std(y_pred; dims = 3), dims=3)
269+
σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
269270

270271
#describe(σ_o_post)
271272
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
272273
mean(σ_o_post, dims = 2), sqrt.(mean(abs2, σ_o_post, dims = 2)))
274+
# VI predicted uncertainty is smaller than HMC predicted one
273275
mean_y_pred = map(mean, eachslice(y_pred; dims = (1, 2)))
274276
#describe(mean_y_pred - y_o)
275277
histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_true)
276278

277279
# look at θP, θM1 of first site
280+
intm_PMs_gen = get_ca_int_PMs(n_site)
281+
ζs, _σ = HVI.generate_ζ(rng, g_flux, f, res.u, xM_gpu,
282+
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred);
283+
ζs = ζs |> Flux.cpu;
278284
θPM = vcat(θP_true, θMs_true[:, 1])
279285
intm = ComponentArrayInterpreter(θPM, (n_sample_pred,))
280286
ζs1c = intm(ζs[1:length(θPM), :])

src/elbo.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,22 @@ expected value of the likelihood of observations.
2020
"""
2121
function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractMatrix,
2222
transPMs, interpreters::NamedTuple;
23-
n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler())
24-
ζs, logdetΣ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
23+
n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler(),
24+
entropyN = 0.0,
25+
)
26+
ζs, σ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
2527
ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
2628
#ζi = first(eachcol(ζs_cpu))
2729
nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
2830
y_pred_i, logjac = predict_y(ζi, f, transPMs)
2931
nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
3032
nLy1 - logjac
3133
end) / n_MC
32-
ent = entropy_MvNormal(size(ζs, 1), logdetΣ) # defined in logden_normal
33-
nLy - ent
34+
logdet_jacT2 = sum_log_σ = sum(log.(σ))
35+
# logdet_jacT2 = -sum_log_σ # log Prod(1/σ_i) = -sum log σ_i
36+
# logdetΣ = 2 * sum_log_σ # log Prod(σ_i²) = 2* sum log σ_i
37+
# ent = entropy_MvNormal(size(ζs, 1), logdetΣ) # defined in logden_normal
38+
nLy - logdet_jacT2 - entropyN
3439
end
3540

3641
"""
@@ -45,17 +50,17 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret
4550
gpu_data_handler=get_default_GPUHandler())
4651
n_site = size(xM, 2)
4752
intm_PMs_gen = get_ca_int_PMs(n_site)
48-
tans_PMs_gen = get_transPMs(n_site)
53+
trans_PMs_gen = get_transPMs(n_site)
4954
ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM),
5055
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred)
5156
ζs_cpu = gpu_data_handler(ζs) #
52-
y_pred = stack(map(ζ -> first(predict_y(ζ, f, tans_PMs_gen)), eachcol(ζs_cpu)));
57+
y_pred = stack(map(ζ -> first(predict_y(ζ, f, trans_PMs_gen)), eachcol(ζs_cpu)));
5358
y_pred
5459
end
5560

5661
"""
57-
Generate samples of (inv-transformed) model parameters, ζ, and Log-Determinant
58-
of their distribution.
62+
Generate samples of (inv-transformed) model parameters, ζ,
63+
and the vector of standard deviations, σ, i.e. the diagonal of the cholesky-factor.
5964
6065
Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0`
6166
to the means extracted from parameters and predicted by the machine learning
@@ -68,21 +73,21 @@ function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix,
6873
μ_ζP = ϕc.μP
6974
ϕg = ϕc.ϕg
7075
μ_ζMs0 = g(x, ϕg) # TODO provide μ_ζP to g
71-
ζ_resid, logdetΣ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC)
72-
#ζ_resid, logdetΣ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
76+
ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC)
77+
#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
7378
ζ = stack(map(eachcol(ζ_resid)) do r
7479
rc = interpreters.PMs(r)
7580
ζP = μ_ζP .+ rc.θP
7681
μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g
7782
ζMs = μ_ζMs .+ rc.θMs
7883
vcat(ζP, vec(ζMs))
7984
end)
80-
ζ, logdetΣ
85+
ζ, σ
8186
end
8287

8388
"""
8489
Extract relevant parameters from θ and return n_MC generated draws
85-
together with the logdet of the transformation.
90+
together with the vector of standard deviations, σ.
8691
8792
Necessary typestable information on number of components are provided with
8893
ComponentMarshellers
@@ -115,9 +120,9 @@ function sample_ζ_norm0(urand::AbstractMatrix, ζP::AbstractVector{T}, ζMs::Ab
115120
# need to construct full matrix for CUDA
116121
Uσ = _create_blockdiag(UP, UM, σP, σMs, n_batch)
117122
ζ_resid = Uσ' * urand
118-
logdetΣ = 2 .* sum(log.(diag(Uσ)))
123+
σ = diag(Uσ) # elements of the diagonal: standard deviations
119124
# returns CuArrays to either continue on GPU or need to transfer to CPU
120-
ζ_resid, logdetΣ
125+
ζ_resid, σ
121126
end
122127

123128
function _create_blockdiag(UP::AbstractMatrix{T}, UM, σP, σMs, n_batch) where {T}

test/test_cholesky_structure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ end
241241
optprob = Optimization.OptimizationProblem(optf, Us1vec0)
242242
res = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.02),
243243
#callback=callback_loss(50),
244-
maxiters = 800)
244+
maxiters = 1_000)
245245

246246
Upred = CP.transformU_cholesky1(res.u; n = n_U)
247247
#@test Upred ≈ CU

test/test_elbo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ n_MC = 3
9898
end
9999

100100
@testset "generate_ζ" begin
101-
ζ, logdetΣ = CP.generate_ζ(
101+
ζ, σ = CP.generate_ζ(
102102
rng, g, f, ϕ_ini, xM[:, 1:n_batch], map(get_concrete, interpreters);
103103
n_MC = 8)
104104
@test ζ isa Matrix
@@ -119,7 +119,7 @@ if CUDA.functional()
119119
@testset "generate_ζ gpu" begin
120120
ϕ = CuArray(CA.getdata(ϕ_ini))
121121
xMg_batch = CuArray(xM[:, 1:n_batch])
122-
ζ, logdetΣ = CP.generate_ζ(
122+
ζ, σ = CP.generate_ζ(
123123
rng, g_flux, f, ϕ, xMg_batch, map(get_concrete, interpreters);
124124
n_MC = 8)
125125
@test ζ isa CuMatrix

test/test_logden_normal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end;
3333
@testset "entropy_MvNormal" begin
3434
S = Diagonal([4,5]) .+ rand(2,2)
3535
S2 = Symmetric(S*S)
36-
@test entropy_MvNormal(S2) == entropy(MvNormal(S2))
36+
@test entropy_MvNormal(S2) ≈ entropy(MvNormal(S2))
3737
end;
3838

3939

test/test_sample_zeta.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ scenario = (:default,)
5959
@testset "sample_ζ_norm0 cpu" begin
6060
ϕ = CA.getdata(ϕ_cpu)
6161
ϕc = interpreters.pmu(ϕ)
62-
ζ_resid, logdetΣ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)
62+
ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)
6363
@test size(ζ_resid) == (length(ϕc.P) + n_θM * n_site, n_MC)
6464
gr = Zygote.gradient(ϕc -> sum(CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc)[1]), ϕc)[1]
6565
@test length(gr) == length(ϕ)
@@ -76,9 +76,9 @@ scenario = (:default,)
7676
#ζP, ζMs, ϕunc = ϕc.P, ϕc.Ms, ϕc.unc
7777
#urand = CUDA.randn(length(ϕc.P) + length(ϕc.Ms), n_MC) |> gpu
7878
#include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
79-
#ζ_resid, logdetΣ = sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)
79+
#ζ_resid, σ = sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)
8080
#Zygote.gradient(ϕc -> sum(sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)[1]), ϕc)[1];
81-
ζ_resid, logdetΣ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)
81+
ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)
8282
@test ζ_resid isa GPUArraysCore.AbstractGPUArray
8383
@test size(ζ_resid) == (length(ϕc.P) + n_θM * n_site, n_MC)
8484
gr = Zygote.gradient(

0 commit comments

Comments
 (0)