
Commit fa44a26

Provide a logdensity_obs computing function via the case, and pass observation uncertainty through the dataloader
1 parent 151551d commit fa44a26

19 files changed: +325 −331 lines

dev/doubleMM.jl

Lines changed: 35 additions & 65 deletions
@@ -26,32 +26,32 @@ par_templates = get_hybridcase_par_templates(case; scenario)
 
 (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
 
-(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
+(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
 ) = gen_hybridcase_synthetic(case, rng; scenario);
 
 #----- fit g to θMs_true
 g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
+(; transP, transM) = get_hybridcase_transforms(case; scenario)
 
-function loss_g(ϕg, x, g)
+function loss_g(ϕg, x, g, transM)
     ζMs = g(x, ϕg) # predict the log of the parameters
-    θMs = exp.(ζMs)
+    θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
     loss = sum(abs2, θMs .- θMs_true)
     return loss, θMs
 end
-loss_g(ϕg0, xM, g)
-Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
+loss_g(ϕg0, xM, g, transM)
 
-optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
+optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
     Optimization.AutoZygote())
 optprob = Optimization.OptimizationProblem(optf, ϕg0);
 res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
 
 ϕg_opt1 = res.u;
-loss_g(ϕg_opt1, xM, g)
-scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
-@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
+l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
+scatterplot(vec(θMs_true), vec(θMs_pred))
 
 f = get_hybridcase_PBmodel(case; scenario)
+py = get_hybridcase_neg_logden_obs(case; scenario)
 
 #----------- fit g and θP to y_o
 () -> begin
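
A note on the loss_g change: the hard-coded exp is replaced by the case-supplied transM, applied column by column. A minimal sketch of what the reduce(hcat, map(...)) line computes, using a plain exp-based stand-in (transM_demo is hypothetical; for an elementwise(exp) transform like the one the removed lines built by hand, this reduces to the old behavior):

transM_demo = ζ -> exp.(ζ)                           # stand-in for the case's transM
ζMs = [0.0 1.0; -1.0 2.0]                            # n_θM × n_site, unconstrained scale
θMs = reduce(hcat, map(transM_demo, eachcol(ζMs)))   # n_θM × n_site, native scale
θMs == exp.(ζMs)                                     # true for this elementwise transform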
@@ -82,13 +82,12 @@ f = get_hybridcase_PBmodel(case; scenario)
 end
 
 #---------- HVI
-logσ2y = 2 .* log.(σ_o)
 n_MC = 3
-transP = elementwise(exp)
-transM = Stacked(elementwise(identity), elementwise(exp))
+(; transP, transM) = get_hybridcase_transforms(case; scenario)
+FT = get_hybridcase_float_type(case; scenario)
 
 (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
-    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
+    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
 ϕ_true = ϕ
 
 () -> begin
@@ -149,49 +148,21 @@ transM = Stacked(elementwise(identity), elementwise(exp))
     ϕ_true = inverse_ca(trans_gu, ϕt_true)
 end
 
-ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch
+ϕ_ini0 = ζ = reduce(
+    vcat, (
+        ϕ_true[[:μP]] .* FT(0.001), CA.ComponentVector(ϕg = ϕg0), ϕ_true[[:unc]])) # scratch
 #
-# true values
-ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance
+ϕ_ini = ζ = reduce(
+    vcat, (
+        ϕ_true[[:μP]] .- FT(0.1), ϕ_true[[:ϕg]] .* FT(1.1), ϕ_true[[:unc]])) # slight disturbance
 # hardcoded from HMC inversion
 ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
 ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
 mean_σ_o_MC = 0.006042
 
-# test cost function and gradient
-() -> begin
-    neg_elbo_transnorm_gf(rng, g, f, ϕ_true, y_o[:, 1:n_batch], xM[:, 1:n_batch],
-        transPMs_batch, map(get_concrete, interpreters);
-        n_MC = 8, logσ2y)
-    Zygote.gradient(
-        ϕ -> neg_elbo_transnorm_gf(
-            rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch],
-            transPMs_batch, interpreters; n_MC = 8, logσ2y),
-        CA.getdata(ϕ_true))
-end
-
-# optimize using SimpleChains
-() -> begin
-    train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch)
-
-    optf = Optimization.OptimizationFunction(
-        (ϕ, data) -> begin
-            xM, y_o = data
-            neg_elbo_transnorm_gf(
-                rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
-                map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
-        end,
-        Optimization.AutoZygote())
-    optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader)
-    res = Optimization.solve(
-        optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 800)
-    #optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
-    #res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
-end
-
-ϕ = ϕ_ini |> Flux.gpu;
+ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
 xM_gpu = xM |> Flux.gpu;
-g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
+g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
 
 # optimize using LUX
 () -> begin
@@ -216,27 +187,25 @@ g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario
     g_flux = g_luxs
 end
 
-function fcost(ϕ, xM, y_o)
-    neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o,
-        xM, transPMs_batch, map(get_concrete, interpreters);
-        n_MC = 8, logσ2y = logσ2y)
+function fcost(ϕ, xM, y_o, y_unc)
+    neg_elbo_transnorm_gf(rng, g_flux, f, py, CA.getdata(ϕ), y_o, y_unc,
+        xM, xP, transPMs_batch, map(get_concrete, interpreters);
+        n_MC = 8)
 end
-fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch])
+fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch])
 #Zygote.gradient(fcost, ϕ) |> cpu;
 gr = Zygote.gradient(fcost,
-    CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
-gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
+    CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]),
+    CA.getdata(y_o[:, 1:n_batch]), CA.getdata(y_unc[:, 1:n_batch]));
+gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
 
-train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
-train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
+train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
+#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
 
 optf = Optimization.OptimizationFunction(
     (ϕ, data) -> begin
-        xM, y_o = data
-        fcost(ϕ, xM, y_o)
-        # neg_elbo_transnorm_gf(
-        #     rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
-        #     map(get_concrete, interpreters); n_MC = 5, logσ2y)
+        xM, xP, y_o, y_unc = data
+        fcost(ϕ, xM, y_o, y_unc)
     end,
     Optimization.AutoZygote())
 optprob = Optimization.OptimizationProblem(
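
Batches from the new loader are 4-tuples, which the Optimization closure destructures as xM, xP, y_o, y_unc. A minimal sketch of that layout under MLUtils semantics (shapes are illustrative only; in the dev script xP is a vector of per-site NamedTuples, which DataLoader slices the same way):

using MLUtils
xM, xP = rand(Float32, 5, 20), rand(Float32, 3, 20)     # illustrative covariates/drivers
y_o, y_unc = rand(Float32, 2, 20), fill(-9.2f0, 2, 20)  # observations and log-variances
loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = 10)
for (xM_b, xP_b, y_o_b, y_unc_b) in loader
    @assert size(xM_b, 2) == size(y_unc_b, 2) == 10     # all parts sliced consistently
end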
@@ -256,7 +225,7 @@ end
 ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
 ϕunc_VI = interpreters.unc(ζ_VIc.unc)
 
-hcat(θP_true, exp.(ζ_VIc.μP))
+hcat(log.(θP_true), ϕ_ini.μP, ζ_VIc.μP)
 plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
 #lineplot!(plt, 0.0, 1.1, identity)
 #
@@ -266,11 +235,12 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
 # test predicting correct obs-uncertainty of predictive posterior
 n_sample_pred = 200
 
-y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
+y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
     get_transPMs, get_ca_int_PMs, n_sample_pred);
 size(y_pred) # n_obs x n_site, n_sample_pred
 
 σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
+σ_o = exp.(y_unc[:,1] / 2)
 
 #describe(σ_o_post)
 hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
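
The σ_o line inverts the uncertainty encoding this commit introduces: y_unc stores log-variances (y_unc = log σ² = 2 log σ), so the observation standard deviation is recovered as exp(y_unc / 2). A one-line check:

σ_o = 0.01
y_unc1 = 2 * log(σ_o)     # encoding, as in gen_hybridcase_synthetic
exp(y_unc1 / 2) ≈ σ_o     # decoding, as in the line above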

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 12 additions & 10 deletions
@@ -3,12 +3,14 @@ module HybridVariationalInferenceFluxExt
 using HybridVariationalInference, Flux
 using HybridVariationalInference: HybridVariationalInference as HVI
 using ComponentArrays: ComponentArrays as CA
+using Random
 
 struct FluxApplicator{RT} <: AbstractModelApplicator
     rebuild::RT
 end
 
-function HVI.construct_FluxApplicator(m::Chain)
+function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
+    # TODO: care for rng and float_type
     ϕ, rebuild = destructure(m)
     FluxApplicator(rebuild), ϕ
 end
@@ -26,17 +28,17 @@ function __init__()
     HVI.set_default_GPUHandler(FluxGPUDataHandler())
 end
 
-function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
-        args...; kwargs...)
-    # constructor with Flux.Chain
-    g, ϕg = construct_FluxApplicator(g_chain)
-    HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
-end
+# function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
+#         args...; kwargs...)
+#     # constructor with Flux.Chain
+#     g, ϕg = construct_FluxApplicator(g_chain)
+#     HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
+# end
 
-function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
+function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
         scenario::NTuple = ())
     (; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
-    FloatType = get_hybridcase_FloatType(case; scenario)
+    float_type = get_hybridcase_float_type(case; scenario)
     n_out = n_θM
     is_using_dropout = :use_dropout ∈ scenario
     is_using_dropout && error("dropout scenario not supported with Flux yet.")
@@ -47,7 +49,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
         # dense layer without bias that maps to n outputs and `identity` activation
         Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
     )
-    construct_FluxApplicator(g_chain)
+    construct_ChainsApplicator(rng, g_chain, float_type)
 end
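
Background on the FluxApplicator pattern (apply_model itself is untouched by this commit): Flux.destructure splits a Chain into a flat parameter vector plus a rebuild closure, which is what lets the applicator evaluate the network for any candidate ϕ. A self-contained sketch:

using Flux
m = Chain(Dense(4 => 8, tanh), Dense(8 => 2))
ϕ, rebuild = Flux.destructure(m)       # flat parameters and a rebuild closure
x = rand(Float32, 4, 3)
y = rebuild(ϕ)(x)                      # rebuild the chain from ϕ, then apply it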

ext/HybridVariationalInferenceLuxExt.jl

Lines changed: 9 additions & 8 deletions
@@ -10,8 +10,8 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator
     int_ϕ::IT
 end
 
-function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
-    ps, st = Lux.setup(Random.default_rng(), m)
+function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type=Float32; device = gpu_device())
+    ps, st = Lux.setup(rng, m)
     ps_ca = float_type.(CA.ComponentArray(ps))
     st = st |> device
     stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
@@ -25,11 +25,12 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ)
     app.stateful_layer(x, ϕc)
 end
 
-function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
-        args...; device = gpu_device(), kwargs...)
-    # constructor with SimpleChain
-    g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
-    HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
-end
+# function HVI.HybridProblem(rng::AbstractRNG,
+#         θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
+#         args...; device = gpu_device(), kwargs...)
+#     # constructor with SimpleChain
+#     g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM); device)
+#     HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
+# end
 
 end # module

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 6 additions & 12 deletions
@@ -4,31 +4,25 @@ using HybridVariationalInference, SimpleChains
 using HybridVariationalInference: HybridVariationalInference as HVI
 using StatsFuns: logistic
 using ComponentArrays: ComponentArrays as CA
+using Random
 
 
 
 struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
     m::MT
 end
 
-function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
-    ϕ = SimpleChains.init_params(m, FloatType);
+function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::SimpleChain, FloatType=Float32)
+    ϕ = SimpleChains.init_params(m, FloatType; rng);
     SimpleChainsApplicator(m), ϕ
 end
 
 HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
 
-function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
-        args...; kwargs...)
-    # constructor with SimpleChain
-    g, ϕg = construct_SimpleChainsApplicator(g_chain)
-    HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
-end
-
-function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
+function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
        scenario::NTuple=())
    (;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
-    FloatType = get_hybridcase_FloatType(case; scenario)
+    FloatType = get_hybridcase_float_type(case; scenario)
    n_out = n_θM
    is_using_dropout = :use_dropout ∈ scenario
    g_chain = if is_using_dropout
@@ -52,7 +46,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
         TurboDense{false}(identity, n_out)
     )
     end
-    construct_SimpleChainsApplicator(g_chain, FloatType)
+    construct_ChainsApplicator(rng, g_chain, FloatType)
 end
 
 end # module
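
After this commit, all three extensions implement the single HVI.construct_ChainsApplicator entry point, so the backend is chosen by the chain type alone. A hedged usage sketch (the chain variables are placeholders, not names from the repo; the returned applicator is called like a function, as in the dev script):

using Random
rng = Random.default_rng()
# g, ϕg = construct_ChainsApplicator(rng, flux_chain, Float32)    # Flux backend
# g, ϕg = construct_ChainsApplicator(rng, lux_chain, Float32)     # Lux backend
# g, ϕg = construct_ChainsApplicator(rng, simple_chain, Float32)  # SimpleChains backend
# ζMs = g(xM, ϕg)  # dispatches to HVI.apply_model for the chosen backend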

src/DoubleMM/f_doubleMM.jl

Lines changed: 11 additions & 4 deletions
@@ -22,10 +22,14 @@ function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
     (; θP, θM)
 end
 
-function HVI.get_hybridcase_transforms(::AbstractHybridCase; scenario::NTuple = ())
+function HVI.get_hybridcase_transforms(::DoubleMMCase; scenario::NTuple = ())
     (; transP, transM)
 end
 
+function HVI.get_hybridcase_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ())
+    neg_logden_indep_normal
+end
+
 function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
     n_covar_pc = 2
     n_covar = n_covar_pc + 3 # linear dependent
@@ -46,7 +50,7 @@ function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
     end
 end
 
-# function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario)
+# function HVI.get_hybridcase_float_type(::DoubleMMCase; scenario)
 #     return Float32
 # end
 
@@ -58,7 +62,7 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
     n_covar_pc = 2
     n_site = 200
     (; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
-    FloatType = get_hybridcase_FloatType(case; scenario)
+    FloatType = get_hybridcase_float_type(case; scenario)
     xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
         rhodec = 8, is_using_dropout = false)
     int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
@@ -68,6 +72,7 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
     xP = fill((;S1=xP_S1, S2=xP_S2), n_site)
     y_global_true, y_true = f(θP, θMs_true, xP)
     σ_o = FloatType(0.01)
+    logσ2_o = FloatType(2) .* log.(σ_o)
     #σ_o = 0.002
     y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o
     y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o
@@ -81,9 +86,11 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
         y_true,
         y_global_o,
         y_o,
-        σ_o = fill(σ_o, size(y_true,1)),
+        y_unc = fill(logσ2_o, size(y_o)),
     )
 end
 
 
 
+
+
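
get_hybridcase_neg_logden_obs returns neg_logden_indep_normal, whose body is not part of this diff. A minimal sketch of what a negative log-density of independent normal observations with per-observation log-variances would compute (assumed form with constants dropped; the package's actual function may differ):

function neg_logden_indep_normal_sketch(y_o, y_pred, y_unc)
    # -log p(y_o | y_pred, σ²) with σ² = exp(y_unc), summed over observations;
    # the per-observation constant log(2π)/2 is omitted
    sum(@. y_unc / 2 + abs2(y_o - y_pred) / (2 * exp(y_unc)))
end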
