Skip to content

Commit a76b5eb

Browse files
committed
test distribution oof generated residuals
1 parent e49dcb9 commit a76b5eb

File tree

7 files changed

+173
-46
lines changed

7 files changed

+173
-46
lines changed

dev/doubleMM.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ n_epoch = 40
160160
# update the problem with optimized parameters, including uncertainty
161161
prob1o = probo;
162162
n_sample_pred = 400
163-
#(; θ, y) = predict_gf(rng, prob1o, xM, xP; scenario, n_sample_pred);
164-
(; θ, y) = predict_gf(rng, prob1o; scenario, n_sample_pred);
163+
#(; θ, y) = predict_hvi(rng, prob1o, xM, xP; scenario, n_sample_pred);
164+
(; θ, y) = predict_hvi(rng, prob1o; scenario, n_sample_pred);
165165
(θ1, y1) = (θ, y);
166166

167167
() -> begin # prediction with fitted parameters (should be smaller than mean)
@@ -210,7 +210,7 @@ end
210210
prob2o_indep = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
211211
# test predicting correct obs-uncertainty of predictive posterior
212212
n_sample_pred = 400
213-
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
213+
(; θ, y, entropy_ζ) = predict_hvi(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
214214
(θ2_indep, y2_indep) = (θ, y)
215215
#(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed
216216
end
@@ -241,12 +241,12 @@ end
241241
#ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
242242
ϕunc_VI = interpreters.unc(ζ_VIc.unc)
243243
ϕunc_VI.ρsM
244-
exp.(ϕunc_VI.logσ2_logP)
244+
exp.(ϕunc_VI.logσ2_ζP)
245245
exp.(ϕunc_VI.coef_logσ2_ζMs[1, :])
246246

247247
# test predicting correct obs-uncertainty of predictive posterior
248248
n_sample_pred = 400
249-
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred);
249+
(; θ, y, entropy_ζ) = predict_hvi(rng, prob2o; scenario, n_sample_pred);
250250
(θ2, y2) = (θ, y)
251251
size(y) # n_obs x n_site, n_sample_pred
252252
size(θ) # n_θP + n_site * n_θM x n_sample
@@ -320,7 +320,7 @@ end
320320

321321
() -> begin # look at distribution of parameters, predictions, and likelihood and elob at one site
322322
function predict_site(probo, i_site)
323-
(; θ, y, entropy_ζ) = predict_gf(rng, probo, xM, xP; scenario, n_sample_pred)
323+
(; θ, y, entropy_ζ) = predict_hvi(rng, probo, xM, xP; scenario, n_sample_pred)
324324
y_site = y[:, i_site, :]
325325
θMs_i = map(i_rep -> θ[:Ms, i_rep][:, i_site], axes(θ, 2))
326326
r1s = map(x -> x[1], θMs_i)

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ include("logden_normal.jl")
8484
export get_ca_starts, get_ca_ends, get_cor_count
8585
include("cholesky.jl")
8686

87-
export neg_elbo_gtf, predict_gf
87+
export neg_elbo_gtf, predict_hvi
8888
include("elbo.jl")
8989

9090
export init_hybrid_params, init_hybrid_ϕunc

src/elbo.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,21 +141,40 @@ end
141141
end
142142

143143
"""
144-
predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters;
144+
predict_hvi([rng], prob::AbstractHybridProblem [,xM, xP]; scenario, ...)
145+
predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix;
145146
get_transPMs, get_ca_int_PMs, n_sample_pred=200, gdev = identity)
146147
147-
Prediction function for hybrid model. Returns an NamedTuple with entries
148-
- `θ`: ComponentArray `(n_θP + n_site * n_θM), n_sample_pred)` of PBM model parameters.
148+
Prediction function for hybrid variational inference parameter model.
149+
150+
## Arguments
151+
- The problem for which to predict
152+
- xM: covariates for the machine-learning model (ML): Matrix (n_θM x n_site_pred).
153+
- xP: model drivers for process based model (PBM): Matrix with (n_site_pred) rows.
154+
If provided a ComponentArray with a Tuple-Axis in rows, the PBM model can
155+
access parts of it, e.g. `xP[:S1,...]`.
156+
157+
## Keyword arguments
158+
- scenario
159+
- n_sample_pred
160+
161+
Returns an NamedTuple `(; y, θsP, θsMs, entropy_ζ)` with entries
149162
- `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions.
163+
- `θsP`: ComponentArray `(n_θP, n_sample_pred)` of PBM model parameters
164+
that are kept constant across sites.
165+
- `θsMs`: ComponentArray `(n_site, n_θM, n_sample_pred)` of PBM model parameters
166+
that vary by site.
167+
- `entropy_ζ`: The entroy of the log-determinant of the transformation of
168+
the set of model parameters, which is involved in uncertainty quantification.
150169
"""
151-
function predict_gf(rng, prob::AbstractHybridProblem; scenario, kwargs...)
170+
function predict_hvi(rng, prob::AbstractHybridProblem; scenario, kwargs...)
152171
dl = get_hybridproblem_train_dataloader(prob; scenario)
153172
dl_dev = gdev_hybridproblem_dataloader(dl; scenario)
154173
# predict for all sites
155174
xM, xP = dl_dev.data[1:2]
156-
predict_gf(rng, prob, xM, xP; scenario, kwargs...)
175+
predict_hvi(rng, prob, xM, xP; scenario, kwargs...)
157176
end
158-
function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
177+
function predict_hvi(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
159178
scenario,
160179
n_sample_pred=200,
161180
gdev=:use_gpu _val_value(scenario) ? gpu_device() : identity,
@@ -184,12 +203,12 @@ function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP;
184203
int_unc = interpreters.unc
185204
transMs = StackedArray(transM, n_batch)
186205
g_dev, ϕ_dev = gdev(g), gdev(ϕ)
187-
predict_gf(rng, g_dev, f, ϕ_dev, xM, xP;
206+
predict_hvi(rng, g_dev, f, ϕ_dev, xM, xP;
188207
int_μP_ϕg_unc, int_unc, transP, transM,
189208
n_sample_pred, cdev, cor_ends, pbm_covar_indices, kwargs...)
190209
end
191210

192-
function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP;
211+
function predict_hvi(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP;
193212
int_μP_ϕg_unc::AbstractComponentArrayInterpreter,
194213
int_unc::AbstractComponentArrayInterpreter,
195214
transP, transM,
@@ -274,6 +293,8 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
274293
# first pass: append μ_ζP_to covars, need ML prediction for magnitude of ζMs
275294
# TODO replace pbm_covar_indices by ComponentArray? dimensions to be type-inferred?
276295
xMP0 = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices)
296+
#Main.@infiltrate_main
297+
277298
μ_ζMs0 = g(xMP0, ϕg)::MT # for gpu restructure returns Any, so apply type
278299
ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_ends, int_unc)
279300
if pbm_covar_indices isa SA.SVector{0}
@@ -362,7 +383,7 @@ ML-model predcitions of size `(n_θM, n_site)`.
362383
363384
## Arguments
364385
* `int_unc`: Interpret vector as ComponentVector with components
365-
ρsP, ρsM, logσ2_logP, coef_logσ2_ζMs(intercept + slope),
386+
ρsP, ρsM, logσ2_ζP, coef_logσ2_ζMs(intercept + slope),
366387
"""
367388
function sample_ζresid_norm(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix,
368389
args...; n_MC, cor_ends, int_unc)
@@ -392,10 +413,10 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM,
392413
UM = transformU_block_cholesky1(ρsM, cor_ends.M)
393414
cf = ϕuncc.coef_logσ2_ζMs
394415
logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs)
395-
logσ2_logP = vec(CA.getdata(ϕuncc.logσ2_logP))
416+
logσ2_ζP = vec(CA.getdata(ϕuncc.logσ2_ζP))
396417
# CUDA cannot multiply BlockDiagonal * Diagonal, construct already those blocks
397418
σMs = reshape(exp.(logσ2_logMs ./ 2), n_θM, :)
398-
σP = exp.(logσ2_logP ./ 2)
419+
σP = exp.(logσ2_ζP ./ 2)
399420
# BlockDiagonal does work with CUDA, but not with combination of Zygote and CUDA
400421
# need to construct full matrix for CUDA
401422
= _create_blockdiag(UP, UM, σP, σMs, n_batch)

src/init_hybrid_params.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function init_hybrid_params(θP::AbstractVector{FT}, θM::AbstractVector{FT},
7272
end
7373

7474
"""
75-
init_hybrid_ϕunc(cor_ends, ρ0=0f0; logσ2_logP, coef_logσ2_ζMs, ρsP, ρsM)
75+
init_hybrid_ϕunc(cor_ends, ρ0=0f0; logσ2_ζP, coef_logσ2_ζMs, ρsP, ρsM)
7676
7777
Initialize vector of additional parameter of the approximate posterior.
7878
@@ -83,7 +83,7 @@ Arguments:
8383
- `coef_logσ2_logM`: default column for `coef_logσ2_ζMs`, defaults to `[-10.0, 0.0]`
8484
8585
Returns a `ComponentVector` of
86-
- `logσ2_logP`: vector of log-variances of ζP (on log scale).
86+
- `logσ2_ζP`: vector of log-variances of ζP (on log scale).
8787
defaults to -10
8888
- `coef_logσ2_ζMs`: offset and slope for the log-variances of ζM scaling with
8989
its value given by columns for each parameter in ζM, defaults to `[-10, 0]`
@@ -94,14 +94,14 @@ function init_hybrid_ϕunc(
9494
cor_ends::NamedTuple,
9595
ρ0::FT = 0.0f0,
9696
coef_logσ2_logM::AbstractVector{FT} = FT[-10.0, 0.0];
97-
logσ2_logP::AbstractVector{FT} = fill(FT(-10.0), cor_ends.P[end]),
97+
logσ2_ζP::AbstractVector{FT} = fill(FT(-10.0), cor_ends.P[end]),
9898
coef_logσ2_ζMs::AbstractMatrix{FT} = reduce(
9999
hcat, (coef_logσ2_logM for _ in 1:cor_ends.M[end])),
100100
ρsP = fill(ρ0, get_cor_count(cor_ends.P)),
101101
ρsM = fill(ρ0, get_cor_count(cor_ends.M)),
102102
) where {FT}
103103
nt = (;
104-
logσ2_logP,
104+
logσ2_ζP,
105105
coef_logσ2_ζMs,
106106
ρsP,
107107
ρsM)

test/test_HybridProblem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ test_with_flux = (scenario) -> begin
263263
@test cdev(ϕ.unc.ρsM)[1] > 0
264264
@test probo.ϕunc == cdev(ϕ.unc)
265265
n_sample_pred = 22
266-
(; y, θsP, θsMs) = predict_gf(
266+
(; y, θsP, θsMs) = predict_hvi(
267267
rng, probo; scenario = scenf, n_sample_pred, is_inferred=Val(true));
268268
(_xM, _xP, _y_o, _y_unc, _i_sites) = get_hybridproblem_train_dataloader(prob; scenario).data
269269
@test size(y) == (size(_y_o)..., n_sample_pred)
@@ -279,7 +279,7 @@ test_with_flux = (scenario) -> begin
279279
@test probo.ϕunc == cdev(ϕ.unc)
280280
# predict using problem and its associated dataloader
281281
n_sample_pred = 201
282-
(; y, θsP, θsMs) = predict_gf(rng, probo; scenario = scenf, n_sample_pred);
282+
(; y, θsP, θsMs) = predict_hvi(rng, probo; scenario = scenf, n_sample_pred);
283283
# to inspect correlations among θP and θMs construct ComponentVector
284284
hpints = HybridProblemInterpreters(prob; scenario)
285285
int_mPMs = stack_ca_int(Val((n_sample_pred,)), get_int_PMst_site(hpints))
@@ -317,7 +317,7 @@ test_with_flux = (scenario) -> begin
317317
);
318318
@test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector
319319
n_sample_pred = 11
320-
(; y, θsP, θsMs) = predict_gf(
320+
(; y, θsP, θsMs) = predict_hvi(
321321
rng, probo; scenario = scenf, n_sample_pred,is_inferred = Val(true));
322322
# @test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations
323323
end;

0 commit comments

Comments
 (0)