@@ -20,17 +20,22 @@ expected value of the likelihood of observations.
2020"""
2121function neg_elbo_transnorm_gf (rng, g, f, ϕ:: AbstractVector , y_ob, x:: AbstractMatrix ,
2222 transPMs, interpreters:: NamedTuple ;
23- n_MC= 3 , logσ2y, gpu_data_handler = get_default_GPUHandler ())
24- ζs, logdetΣ = generate_ζ (rng, g, f, ϕ, x, interpreters; n_MC)
23+ n_MC= 3 , logσ2y, gpu_data_handler = get_default_GPUHandler (),
24+ entropyN = 0.0 ,
25+ )
26+ ζs, σ = generate_ζ (rng, g, f, ϕ, x, interpreters; n_MC)
2527 ζs_cpu = gpu_data_handler (ζs) # differentiable fetch to CPU in Flux package extension
2628 # ζi = first(eachcol(ζs_cpu))
2729 nLy = reduce (+ , map (eachcol (ζs_cpu)) do ζi
2830 y_pred_i, logjac = predict_y (ζi, f, transPMs)
2931 nLy1 = neg_logden_indep_normal (y_ob, y_pred_i, logσ2y)
3032 nLy1 - logjac
3133 end ) / n_MC
32- ent = entropy_MvNormal (size (ζs, 1 ), logdetΣ) # defined in logden_normal
33- nLy - ent
34+ logdet_jacT2 = sum_log_σ = sum (log .(σ))
35+ # logdet_jacT2 = -sum_log_σ # log Prod(1/σ_i) = -sum log σ_i
36+ # logdetΣ = 2 * sum_log_σ # log Prod(σ_i²) = 2* sum log σ_i
37+ # ent = entropy_MvNormal(size(ζs, 1), logdetΣ) # defined in logden_normal
38+ nLy - logdet_jacT2 - entropyN
3439end
3540
3641"""
@@ -45,17 +50,17 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret
4550 gpu_data_handler= get_default_GPUHandler ())
4651 n_site = size (xM, 2 )
4752 intm_PMs_gen = get_ca_int_PMs (n_site)
48- tans_PMs_gen = get_transPMs (n_site)
53+ trans_PMs_gen = get_transPMs (n_site)
4954 ζs, _ = generate_ζ (rng, g, f, CA. getdata (ϕ), CA. getdata (xM),
5055 (; interpreters... , PMs = intm_PMs_gen); n_MC = n_sample_pred)
5156 ζs_cpu = gpu_data_handler (ζs) #
52- y_pred = stack (map (ζ -> first (predict_y (ζ, f, tans_PMs_gen )), eachcol (ζs_cpu)));
57+ y_pred = stack (map (ζ -> first (predict_y (ζ, f, trans_PMs_gen )), eachcol (ζs_cpu)));
5358 y_pred
5459end
5560
5661"""
57- Generate samples of (inv-transformed) model parameters, ζ, and Log-Determinant
58- of their distribution .
62+ Generate samples of (inv-transformed) model parameters, ζ,
63+ and the vector of standard deviations, σ, i.e. the diagonal of the cholesky-factor .
5964
6065Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0`
6166to the means extracted from parameters and predicted by the machine learning
@@ -68,21 +73,21 @@ function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix,
6873 μ_ζP = ϕc. μP
6974 ϕg = ϕc. ϕg
7075 μ_ζMs0 = g (x, ϕg) # TODO provide μ_ζP to g
71- ζ_resid, logdetΣ = sample_ζ_norm0 (rng, μ_ζP, μ_ζMs0, ϕc. unc; n_MC)
72- # ζ_resid, logdetΣ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
76+ ζ_resid, σ = sample_ζ_norm0 (rng, μ_ζP, μ_ζMs0, ϕc. unc; n_MC)
77+ # ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
7378 ζ = stack (map (eachcol (ζ_resid)) do r
7479 rc = interpreters. PMs (r)
7580 ζP = μ_ζP .+ rc. θP
7681 μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g
7782 ζMs = μ_ζMs .+ rc. θMs
7883 vcat (ζP, vec (ζMs))
7984 end )
80- ζ, logdetΣ
85+ ζ, σ
8186end
8287
8388"""
8489Extract relevant parameters from θ and return n_MC generated draws
85- together with the logdet of the transformation .
90+ together with the vector of standard deviations, σ .
8691
8792Necessary typestable information on number of compponents are provided with
8893ComponentMarshellers
@@ -115,9 +120,9 @@ function sample_ζ_norm0(urand::AbstractMatrix, ζP::AbstractVector{T}, ζMs::Ab
115120 # need to construct full matrix for CUDA
116121 Uσ = _create_blockdiag (UP, UM, σP, σMs, n_batch)
117122 ζ_resid = Uσ' * urand
118- logdetΣ = 2 .* sum ( log .( diag (Uσ)))
123+ σ = diag (Uσ) # elements of the diagonal: standard deviations
119124 # returns CuArrays to either continue on GPU or need to transfer to CPU
120- ζ_resid, logdetΣ
125+ ζ_resid, σ
121126end
122127
123128function _create_blockdiag (UP:: AbstractMatrix{T} , UM, σP, σMs, n_batch) where {T}
0 commit comments