Skip to content

Commit 7e1527e

Browse files
authored
Merge pull request #23 from EarthyScience/dev
Use Callable Process-Base Model rather than closure to support JLD loading and fix test on correlations
2 parents 5b9565f + 80d0f4d commit 7e1527e

File tree

8 files changed

+122
-73
lines changed

8 files changed

+122
-73
lines changed

dev/doubleMM.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,13 @@ prob2o = probo;
197197
fname_probos = "intermediate/probos800_$(last(HVI._val_value(scenario))).jld2"
198198
JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o))
199199
tmp = JLD2.load(fname_probos)
200-
# TODO replace function closure by Callable to store
201-
# closure function could not be restored with JLD2
202-
prob1o = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader);
203-
prob2o = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
204200
end
205201

206202
() -> begin # load the non-covar scenario
207203
using JLD2
208204
#fname_probos = "intermediate/probos_$(last(_val_value(scenario))).jld2"
209205
fname_probos = "intermediate/probos800_omit_r0.jld2"
210206
tmp = JLD2.load(fname_probos)
211-
# get_train_loader function could not be restored with JLD2
212-
prob1o_indep = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader);
213-
prob2o_indep = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
214207
# test predicting correct obs-uncertainty of predictive posterior
215208
n_sample_pred = 400
216209
(; θ, y, entropy_ζ) = predict_hvi(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);

src/AbstractHybridProblem.jl

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ function gen_hybridproblem_synthetic end
152152
153153
Determine the FloatType for given Case and scenario, defaults to Float32
154154
"""
155-
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario = ())
155+
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario)
156156
return eltype(get_hybridproblem_par_templates(prob; scenario).θM)
157157
end
158158

@@ -263,4 +263,52 @@ function setup_PBMpar_interpreter(θP, θM, θall = vcat(θP, θM))
263263
θFix = θall[keys_fixed]
264264
intθ = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
265265
intθ, θFix
266-
end
266+
end
267+
268+
struct PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT}
269+
θFix::θFixT
270+
θFix_dev::θFix_devT
271+
intθ::StaticComponentArrayInterpreter{AX}
272+
isP::Matrix{Int}
273+
n_site_batch::Int
274+
pos_xP::pos_xPT
275+
end
276+
277+
function PBmodelClosure(prob::AbstractHybridProblem; scenario::Val{scen},
278+
use_all_sites = false,
279+
gdev = :f_on_gpu _val_value(scenario) ? gpu_device() : identity,
280+
θall, int_xP1,
281+
) where {scen}
282+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
283+
n_site_batch = use_all_sites ? n_site : n_batch
284+
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
285+
par_templates = get_hybridproblem_par_templates(prob; scenario)
286+
intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
287+
θFix = repeat(θFix1', n_site_batch)
288+
θFix_dev = gdev(θFix)
289+
intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1))
290+
#int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
291+
isP = repeat(axes(par_templates.θP, 1)', n_site_batch)
292+
pos_xP = get_positions(int_xP1)
293+
PBmodelClosure(;θFix, θFix_dev, intθ, isP, n_site_batch, pos_xP)
294+
end
295+
296+
function PBmodelClosure(;
297+
θFix::θFixT,
298+
θFix_dev::θFix_devT,
299+
intθ::StaticComponentArrayInterpreter{AX},
300+
isP::Matrix{Int},
301+
n_site_batch::Int,
302+
pos_xP::pos_xPT,
303+
) where {θFixT, θFix_devT, AX, pos_xPT}
304+
PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT}(
305+
θFix::AbstractArray, θFix_dev, intθ, isP, n_site_batch, pos_xP)
306+
end
307+
308+
309+
310+
311+
312+
313+
314+

src/DoubleMM/f_doubleMM.jl

Lines changed: 37 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -181,55 +181,43 @@ end
181181
# end
182182
# end
183183

184-
function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::Val{scen},
185-
use_all_sites = false,
186-
gdev = :f_on_gpu HVI._val_value(scenario) ? gpu_device() : identity
187-
) where {scen}
188-
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
189-
n_site_batch = use_all_sites ? n_site : n_batch
190-
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
191-
par_templates = get_hybridproblem_par_templates(prob; scenario)
192-
intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
193-
θFix = repeat(θFix1', n_site_batch)
194-
intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1))
195-
#int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
196-
isP = repeat(axes(par_templates.θP, 1)', n_site_batch)
197-
let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP = isP,
198-
n_site_batch = n_site_batch,
199-
#int_xPb=get_concrete(int_xPb),
200-
pos_xP = get_positions(int_xP1)
184+
# defining the PBmodel as a closure with let leads to problems of JLD2 reloading
185+
# Define all the variables additional to the ones passed curing the call by
186+
# a dedicated Closure object and define the PBmodel as a callable
187+
struct DoubleMMCaller{CLT}
188+
cl::CLT
189+
end
201190

202-
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
203-
@assert size(xP, 2) == n_site_batch
204-
@assert size(θMs, 1) == n_site_batch
205-
# # convert vector of tuples to tuple of matricesByRows
206-
# # need to supply xP as vectorOfTuples to work with DataLoader
207-
# # k = first(keys(xP[1]))
208-
# xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
209-
# #stack(map(r -> r[k], xP))'
210-
# stack(map(r -> r[k], xP); dims = 1)
211-
# end)...)
212-
#xPM = map(transpose, xPM1)
213-
#xPc = int_xPb(CA.getdata(xP))
214-
#xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
215-
# make sure the same order of columns as in intθ
216-
# reshape big matrix into NamedTuple of drivers S1 and S2
217-
# for broadcasting need sites in rows
218-
#xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
219-
xPM = map(p -> CA.getdata(xP)'[:, p], pos_xP)
220-
θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? θFix_dev : θFix
221-
θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs), θFixd)
222-
pred_sites = f_doubleMM(θ, xPM; intθ)'
223-
pred_global = eltype(pred_sites)[]
224-
return pred_global, pred_sites
225-
end
226-
# function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
227-
# # TODO
228-
# pred_sites = f_doubleMM(θMs, θP, θFix, xP, intθ)
229-
# pred_global = eltype(pred_sites)[]
230-
# return pred_global, pred_sites
231-
# end
232-
end
191+
function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario, kwargs...)
192+
# θall defined in this module above
193+
cl = HVI.PBmodelClosure(prob; scenario, θall, int_xP1, kwargs...)
194+
return DoubleMMCaller{typeof(cl)}(cl)
195+
end
196+
197+
function(caller::DoubleMMCaller)(θP::AbstractVector, θMs::AbstractMatrix, xP)
198+
cl = caller.cl
199+
@assert size(xP, 2) == cl.n_site_batch
200+
@assert size(θMs, 1) == cl.n_site_batch
201+
# # convert vector of tuples to tuple of matricesByRows
202+
# # need to supply xP as vectorOfTuples to work with DataLoader
203+
# # k = first(keys(xP[1]))
204+
# xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
205+
# #stack(map(r -> r[k], xP))'
206+
# stack(map(r -> r[k], xP); dims = 1)
207+
# end)...)
208+
#xPM = map(transpose, xPM1)
209+
#xPc = int_xPb(CA.getdata(xP))
210+
#xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
211+
# make sure the same order of columns as in intθ
212+
# reshape big matrix into NamedTuple of drivers S1 and S2
213+
# for broadcasting need sites in rows
214+
#xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
215+
xPM = map(p -> CA.getdata(xP)'[:, p], cl.pos_xP)
216+
θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? cl.θFix_dev : cl.θFix
217+
θ = hcat(CA.getdata(θP[cl.isP]), CA.getdata(θMs), θFixd)
218+
pred_sites = f_doubleMM(θ, xPM; cl.intθ)'
219+
pred_global = eltype(pred_sites)[]
220+
return pred_global, pred_sites
233221
end
234222

235223
function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::Val)
@@ -284,8 +272,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
284272
xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site)))
285273
#xP[:S1,:]
286274
θP = par_templates.θP
287-
#θint = ComponentArrayInterpreter( (size(θMs_true,2),), CA.getaxes(vcat(θP, θMs_true[:,1])))
288-
y_global_true, y_true = f(θP, θMs_true', xP)
275+
y_global_true, y_true = f(θP, θMs_true', xP)
289276
σ_o = FloatType(0.01)
290277
#σ_o = FloatType(0.002)
291278
logσ2_o = FloatType(2) .* log.(σ_o)

src/HybridProblem.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ struct HybridProblem <: AbstractHybridProblem
2121
θP::CA.ComponentVector, θM::CA.ComponentVector,
2222
g::AbstractModelApplicator, ϕg::AbstractVector,
2323
ϕunc::CA.ComponentVector,
24-
f_batch::Function,
25-
f_allsites::Function,
24+
f_batch,
25+
f_allsites,
2626
priors::AbstractDict,
2727
py,
2828
transM::Stacked,
@@ -43,7 +43,7 @@ end
4343

4444
function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector,
4545
# note no ϕg argument and g_chain unconstrained
46-
g_chain, f_batch::Function,
46+
g_chain, f_batch,
4747
args...; rng = Random.default_rng(), kwargs...)
4848
# dispatches on type of g_chain
4949
g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM))
@@ -74,10 +74,10 @@ function update(prob::HybridProblem;
7474
g::AbstractModelApplicator = prob.g,
7575
ϕg::AbstractVector = prob.ϕg,
7676
ϕunc::CA.ComponentVector = prob.ϕunc,
77-
f_batch::Function = prob.f_batch,
78-
f_allsites::Function = prob.f_allsites,
77+
f_batch = prob.f_batch,
78+
f_allsites = prob.f_allsites,
7979
priors::AbstractDict = prob.priors,
80-
py::Function = prob.py,
80+
py = prob.py,
8181
# transM::Union{Function, Bijectors.Transform} = prob.transM,
8282
# transP::Union{Function, Bijectors.Transform} = prob.transP,
8383
transM = prob.transM,

src/cholesky.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,19 @@ end
176176
# U = _create_blockdiag(v[first(keys(v))], blocks) # v only for dispatch: plain matrix for gpu
177177
# end
178178

179+
"""
180+
Compute the cholesky-factor parameter for a given single
181+
correlation in a 2x2 matrix.
182+
Invert the transformation of cholesky-factor parameterization.
183+
"""
184+
function compute_cholcor_coefficient_single(ρ)
185+
# invert ρ = a / sqrt(a^2 + 1)
186+
sign(ρ) * sqrt^2/(1 - ρ^2))
187+
end
188+
189+
190+
191+
179192
"""
180193
get_ca_starts(vc::ComponentVector)
181194

src/elbo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,8 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM,
411411
ρsP = isempty(ϕuncc.ρsP) ? similar(ϕuncc.ρsP) : ϕuncc.ρsP # required by zygote
412412
UP = transformU_block_cholesky1(ρsP, cor_ends.P)
413413
ρsM = isempty(ϕuncc.ρsM) ? similar(ϕuncc.ρsM) : ϕuncc.ρsM # required by zygote
414+
# cholesky factor of the correlation: diag(UM' * UM) .== 1
415+
# coefficients ρsM can be larger than 1, still yielding correlations <1 in UM' * UM
414416
UM = transformU_block_cholesky1(ρsM, cor_ends.M)
415417
cf = ϕuncc.coef_logσ2_ζMs
416418
logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs)

test/test_doubleMM.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ prob = DoubleMM.DoubleMMCase()
2525
scenario = Val((:default,))
2626
#using Flux
2727
#scenario = Val((:use_Flux,))
28+
#scenario = Val((:use_Flux,:f_on_gpu))
2829

2930
par_templates = get_hybridproblem_par_templates(prob; scenario)
3031

test/test_elbo.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,13 @@ test_scenario = (scenario) -> begin
130130
ϕunc_true.logσ2_ζP = (log abs2).(sd_ζP_true)
131131
ϕunc_true.coef_logσ2_ζMs[1,:] = (log abs2).(sd_ζMs_a_true)
132132
ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true
133-
ϕunc_true.ρsP = ρsP_true
134-
ϕunc_true.ρsM = ρsM_true
133+
# note that the parameterization contains a transformation that
134+
# here only inverted for the single correlation case
135+
ϕunc_true.ρsP = CP.compute_cholcor_coefficient_single.(ρsP_true)
136+
ϕunc_true.ρsM = CP.compute_cholcor_coefficient_single.(ρsM_true)
137+
# check that ρsM_true = -0.6 recovered with params ϕunc_true.ρsM = -0.75
138+
UC = CP.transformU_cholesky1(ϕunc_true.ρsM); Σ = UC' * UC
139+
@test Σ[1,2] ρsM_true[1]
135140

136141
probd = CP.update(probc; ϕunc=ϕunc_true);
137142
= vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc)
@@ -189,12 +194,12 @@ test_scenario = (scenario) -> begin
189194
residPMst = vcat(residP,
190195
reshape(residMst, size(residMst,1)*size(residMst,2), size(residMst,3)))
191196
cor_PMs = cor(residPMst')
192-
@test cor_PMs[1,2] ρsP_true[1] atol=0.2
193-
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.2)) # no correlations P,M
194-
@test cor_PMs[3,4] ρsM_true[1] atol=0.2
195-
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.2)) # no correlations M1, M2
196-
@test cor_PMs[5,6] ρsM_true[1] atol=0.2
197-
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.2)) # no correlations M1, M2
197+
@test cor_PMs[1,2] ρsP_true[1] atol=0.02
198+
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.02)) # no correlations P,M
199+
@test cor_PMs[3,4] ρsM_true[1] atol=0.02
200+
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.02)) # no correlations M1, M2
201+
@test cor_PMs[5,6] ρsM_true[1] atol=0.02
202+
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.02)) # no correlations M1, M2
198203
end
199204
test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g)
200205
@testset "predict_hvi check sd" begin

0 commit comments

Comments
 (0)