@@ -22,9 +22,9 @@ expected value of the likelihood of observations.
2222function neg_elbo_transnorm_gf (rng, g, f, ϕ:: AbstractVector , y_ob, xM:: AbstractMatrix ,
2323 xP, transPMs, interpreters:: NamedTuple ;
2424 n_MC= 3 , logσ2y, gpu_data_handler = get_default_GPUHandler (),
25- entropyN = 0.0 ,
25+ cor_starts = (P = ( 1 ,),M = ( 1 ,))
2626 )
27- ζs, σ = generate_ζ (rng, g, f, ϕ, xM, interpreters; n_MC)
27+ ζs, σ = generate_ζ (rng, g, ϕ, xM, interpreters; n_MC, cor_starts )
2828 ζs_cpu = gpu_data_handler (ζs) # differentiable fetch to CPU in Flux package extension
2929 # ζi = first(eachcol(ζs_cpu))
3030 nLy = reduce (+ , map (eachcol (ζs_cpu)) do ζi
@@ -48,13 +48,14 @@ Prediction function for hybrid model. Returns an Array `(n_obs, n_site, n_sample
4848"""
4949function predict_gf (rng, g, f, ϕ:: AbstractVector , xM:: AbstractMatrix , xP, interpreters;
5050 get_transPMs, get_ca_int_PMs, n_sample_pred= 200 ,
51- gpu_data_handler= get_default_GPUHandler ())
51+ gpu_data_handler= get_default_GPUHandler (),
52+ cor_starts= (P= (1 ,),M= (1 ,)))
5253 n_site = size (xM, 2 )
5354 intm_PMs_gen = get_ca_int_PMs (n_site)
5455 trans_PMs_gen = get_transPMs (n_site)
5556 interpreters_gen = (; interpreters... , PMs = intm_PMs_gen)
56- ζs, _ = generate_ζ (rng, g, f, CA. getdata (ϕ), CA. getdata (xM),
57- interpreters_gen; n_MC = n_sample_pred)
57+ ζs, _ = generate_ζ (rng, g, CA. getdata (ϕ), CA. getdata (xM),
58+ interpreters_gen; n_MC = n_sample_pred, cor_starts )
5859 ζs_cpu = gpu_data_handler (ζs) #
5960 y_pred = stack (map (ζ -> first (predict_y (
6061 ζ, xP, f, trans_PMs_gen, interpreters_gen. PMs)), eachcol (ζs_cpu)));
@@ -69,14 +70,14 @@ Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0`
6970to the means extracted from parameters and predicted by the machine learning
7071model.
7172"""
72- function generate_ζ (rng, g, f, ϕ:: AbstractVector , xM:: AbstractMatrix ,
73- interpreters:: NamedTuple ; n_MC= 3 )
73+ function generate_ζ (rng, g, ϕ:: AbstractVector , xM:: AbstractMatrix ,
74+ interpreters:: NamedTuple ; n_MC= 3 , cor_starts = (P = ( 1 ,),M = ( 1 ,)) )
7475 # see documentation of neg_elbo_transnorm_gf
7576 ϕc = interpreters. μP_ϕg_unc (CA. getdata (ϕ))
7677 μ_ζP = ϕc. μP
7778 ϕg = ϕc. ϕg
7879 μ_ζMs0 = g (xM, ϕg) # TODO provide μ_ζP to g
79- ζ_resid, σ = sample_ζ_norm0 (rng, μ_ζP, μ_ζMs0, ϕc. unc; n_MC)
80+ ζ_resid, σ = sample_ζ_norm0 (rng, μ_ζP, μ_ζMs0, ϕc. unc; n_MC, cor_starts )
8081 # ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
8182 ζ = stack (map (eachcol (ζ_resid)) do r
8283 rc = interpreters. PMs (r)
@@ -98,21 +99,21 @@ ComponentMarshellers
9899- marsh_batch(n_batch)
99100- marsh_unc(n_UncP, n_UncM, n_UncCorr)
100101"""
101- function sample_ζ_norm0 (rng:: Random.AbstractRNG , ζP:: AbstractVector , ζMs:: AbstractMatrix , ϕunc :: AbstractVector , args ... ;
102- n_MC= 3 )
102+ function sample_ζ_norm0 (rng:: Random.AbstractRNG , ζP:: AbstractVector , ζMs:: AbstractMatrix ,
103+ args ... ; n_MC, cor_starts)
103104 n_θP, n_θMs = length (ζP), length (ζMs)
104105 urand = _create_random (rng, CA. getdata (ζP), n_θP + n_θMs, n_MC)
105- sample_ζ_norm0 (urand, ζP, ζMs, ϕunc, args... )
106+ sample_ζ_norm0 (urand, ζP, ζMs, args... ; cor_starts )
106107end
107108
108109function sample_ζ_norm0 (urand:: AbstractMatrix , ζP:: AbstractVector{T} , ζMs:: AbstractMatrix ,
109- ϕunc:: AbstractVector , int_unc = ComponentArrayInterpreter (ϕunc);
110+ ϕunc:: AbstractVector , int_unc = ComponentArrayInterpreter (ϕunc); cor_starts
110111 ) where {T}
111112 ϕuncc = int_unc (CA. getdata (ϕunc))
112113 n_θP, n_θMs, (n_θM, n_batch) = length (ζP), length (ζMs), size (ζMs)
113114 # make sure to not create a UpperTriangular Matrix of an CuArray in transformU_cholesky1
114- UP = transformU_cholesky1 (ϕuncc. ρsP)
115- UM = transformU_cholesky1 (ϕuncc. ρsM)
115+ UP = transformU_block_cholesky1 (ϕuncc. ρsP, cor_starts . P )
116+ UM = transformU_block_cholesky1 (ϕuncc. ρsM, cor_starts . M )
116117 cf = ϕuncc. coef_logσ2_logMs
117118 logσ2_logMs = vec (cf[1 , :] .+ cf[2 , :] .* ζMs)
118119 logσ2_logP = vec (CA. getdata (ϕuncc. logσ2_logP))
0 commit comments