8 changes: 8 additions & 0 deletions Project.toml
@@ -4,8 +4,12 @@ authors = ["Thomas Wutzler <[email protected]> and contributors"]
version = "1.0.0-DEV"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -21,9 +25,13 @@ HybridVariationalInferenceLuxExt = "Lux"
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"

[compat]
ChainRulesCore = "1.25"
CUDA = "5.5.2"
Combinatorics = "1.0.2"
ComponentArrays = "0.15.19"
Flux = "v0.15.2"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10.0"
Lux = "1.4.2"
Random = "1.10.0"
SimpleChains = "0.4"
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@
Extending Variational Inference (VI), an approximate Bayesian inversion method,
to hybrid models, i.e. models that combine mechanistic and machine-learning parts.

The model inversion infers parametric approximations of the posterior density
of model parameters by comparing model outputs to uncertain observations. At
the same time, a machine-learning model is fitted that predicts the parameters
of these approximations from covariates.
2 changes: 2 additions & 0 deletions dev/Project.toml
@@ -1,5 +1,6 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
@@ -11,5 +12,6 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
253 changes: 235 additions & 18 deletions dev/doubleMM.jl
@@ -65,28 +65,245 @@ loss_g(ϕg_opt1, xM, g)
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9

tmpf = () -> begin
    #----------- fit g and θP to y_o
    # end-to-end inversion
    f = gen_hybridcase_PBmodel(case; scenario)

    int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
        ϕg = 1:length(ϕg0), θP = par_templates.θP))
    p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true

    # Pass the site-data for the batches as separate vectors wrapped in a tuple
    train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)

    loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
    l1 = loss_gf(p0, train_loader.data...)[1]

    optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
        Optimization.AutoZygote())
    optprob = OptimizationProblem(optf, p0, train_loader)

    res = Optimization.solve(
        optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);

    l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
    scatterplot(vec(θMs_true), vec(θMs))
    scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
    scatterplot(vec(y_pred), vec(y_o))
    hcat(par_templates.θP, int_ϕθP(res.u).θP)
end

#---------- HADVI
# TODO think about good general initializations
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
logσ2_logP = CA.ComponentVector(r0=-8.997, K2=-5.893)
mean_σ_o_MC = 0.006042
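# A minimal sketch of the assumed magnitude-dependent uncertainty model (an
# interpretation based on the ϕunc0 default below, not stated in this diff):
# each column of coef_logσ2_logMs holds an (intercept, slope) pair, so that
# logσ²(θM) = intercept + slope * log(θM).
let c = coef_logσ2_logMs[:, 1], θM1 = 0.5
    logσ2 = c[1] + c[2] * log(θM1)
    exp(logσ2 / 2)  # implied standard deviation on the log scale
end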

# correlation matrices
ρsP = zeros(sum(1:(n_θP-1)))
ρsM = zeros(sum(1:(n_θM-1)))
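# Why sum(1:(n-1)) entries: an n×n correlation matrix has n*(n-1)/2 free
# below-diagonal entries, which ρsP and ρsM parameterize. Plain-Julia check:
let n = 4
    @assert sum(1:(n - 1)) == n * (n - 1) ÷ 2
end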

ϕunc = CA.ComponentVector(;
    logσ2_logP=logσ2_logP,
    coef_logσ2_logMs=coef_logσ2_logMs,
    ρsP,
    ρsM)
int_unc = ComponentArrayInterpreter(ϕunc)

# for a conservative uncertainty, assume a small variance (logσ² = -10) and no relationship with magnitude
ϕunc0 = CA.ComponentVector(;
    logσ2_logP=fill(-10.0, n_θP),
    coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
    ρsP,
    ρsM)

logσ2y = fill(2 * log(σ_o), size(y_o, 1))
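# logσ2y holds the log observation variance per data stream: 2 * log(σ_o) == log(σ_o^2)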
n_MC = 3


#-------------- ADVI with g inside cost function
using CUDA
using TransformVariables

transPMs_batch = as(
    (P=as(Array, asℝ₊, n_θP),
    Ms=as(Array, asℝ₊, n_θM, n_batch)))
transPMs_all = as(
    (P=as(Array, asℝ₊, n_θP),
    Ms=as(Array, asℝ₊, n_θM, n_site)))
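# Sketch of how these transforms behave (standard TransformVariables API):
# `dimension` gives the length of the unconstrained vector ζ, and `transform`
# maps ζ to a NamedTuple with strictly positive entries.
let ζ_demo = randn(TransformVariables.dimension(transPMs_batch))
    nt = TransformVariables.transform(transPMs_batch, ζ_demo)
    @assert size(nt.Ms) == (n_θM, n_batch) && all(>(0), nt.P)
end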

ϕ_true = θ = CA.ComponentVector(;
    μP=θP_true,
    ϕg=ϕg_opt,
    unc=ϕunc);
trans_gu = as(
    (μP=as(Array, asℝ₊, n_θP),
    ϕg=as(Array, n_ϕg),
    unc=as(Array, length(ϕunc))))
trans_g = as(
    (μP=as(Array, asℝ₊, n_θP),
    ϕg=as(Array, n_ϕg)))

const int_PMs_batch = ComponentArrayInterpreter(CA.ComponentVector(; θP,
    θMs=CA.ComponentMatrix(
        zeros(n_θM, n_batch), first(CA.getaxes(θM)), CA.Axis(i=1:n_batch))))

interpreters = interpreters_g = map(get_concrete, (;
    μP_ϕg_unc=ComponentArrayInterpreter(ϕ_true),
    PMs=int_PMs_batch,
    unc=ComponentArrayInterpreter(ϕunc)
))

ϕg_true_vec = CA.ComponentVector(
    TransformVariables.inverse(trans_gu, cv2NamedTuple(ϕ_true)))
ϕcg_true = interpreters.μP_ϕg_unc(ϕg_true_vec)
ϕ_ini = ζ = vcat(ϕcg_true[[:μP, :ϕg]] .* 1.2, ϕcg_true[[:unc]]);
ϕ_ini0 = ζ = vcat(ϕcg_true[:μP] .* 0.0, SimpleChains.init_params(g), ϕunc0);
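# ϕ_ini perturbs the true μP and network weights by 20% to test recovery;
# ϕ_ini0 instead starts μP at zero, reinitializes g, and uses the conservative ϕunc0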

neg_elbo_transnorm_gf(rng, g, f, ϕcg_true, y_o[:, 1:n_batch], x_o[:, 1:n_batch],
    transPMs_batch, map(get_concrete, interpreters);
    n_MC=8, logσ2y)
Zygote.gradient(ϕ -> neg_elbo_transnorm_gf(
        rng, g, f, ϕ, y_o[:, 1:n_batch], x_o[:, 1:n_batch],
        transPMs_batch, interpreters; n_MC=8, logσ2y), ϕcg_true)
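# A hedged gradient sanity check (a sketch; fc, v, δ, dd_fd, dd_ad are local
# names introduced here). Seeding the RNG freshly on every call keeps the MC
# noise identical, so a central finite difference along a random direction
# should approximate the Zygote directional derivative. Assumes
# neg_elbo_transnorm_gf returns a scalar and accepts any AbstractRNG.
fc = ϕ -> neg_elbo_transnorm_gf(Random.Xoshiro(0), g, f, ϕ,
    y_o[:, 1:n_batch], x_o[:, 1:n_batch], transPMs_batch, interpreters;
    n_MC=8, logσ2y)
v = randn(length(ϕcg_true)); δ = 1e-6
dd_fd = (fc(ϕcg_true .+ δ .* v) - fc(ϕcg_true .- δ .* v)) / (2δ)
dd_ad = sum(CA.getdata(Zygote.gradient(fc, ϕcg_true)[1]) .* v)
(dd_fd, dd_ad)  # should agree closely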

() -> begin
    train_loader = MLUtils.DataLoader((x_o, y_o), batchsize = n_batch)

    optf = Optimization.OptimizationFunction((ζg, data) -> begin
            x_o, y_o = data
            neg_elbo_transnorm_gf(
                rng, g, f, ζg, y_o, x_o, transPMs_batch,
                map(get_concrete, interpreters_g); n_MC=5, logσ2y)
        end,
        Optimization.AutoZygote())
    optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader);
    res = Optimization.solve(optprob, Optimisers.Adam(0.02), callback=callback_loss(50), maxiters=800);
    #optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
    #res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
end

#using Lux
ϕ = ϕcg_true |> gpu;
x_o_gpu = x_o |> gpu;
# y_o = y_o |> gpu
# logσ2y = logσ2y |> gpu
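# `gpu` (re-exported by Flux) moves arrays to the CUDA device when one is
# available and is a no-op otherwise; y_o and logσ2y are left on the CPU here,
# which the cost function is assumed to handle.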
n_covar = size(x_o, 1)
g_flux = Flux.Chain(
    # dense layer with bias, mapping n_covar inputs to n_covar*4 outputs, with tanh activation
    Flux.Dense(n_covar => n_covar * 4, tanh),
    Flux.Dense(n_covar * 4 => n_covar * 4, logistic),
    # dense layer without bias, mapping to n_θM outputs, with identity activation
    Flux.Dense(n_covar * 4 => n_θM, identity, bias=false),
)
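# Quick shape check for the Flux chain (a sketch; ϕg_flux and re are local
# names introduced here). Flux.destructure flattens the parameters into one
# vector and returns a rebuilder; how the package itself consumes a flat ϕg
# is not shown in this diff.
let (ϕg_flux, re) = Flux.destructure(g_flux)
    @assert size(re(ϕg_flux)(x_o[:, 1:n_batch])) == (n_θM, n_batch)
end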
() -> begin
    # `using Lux` must be issued at top level (see the commented `#using Lux` above) before running this block
    g_lux = Lux.Chain(
        # dense layer with bias, mapping n_covar inputs to n_covar*4 outputs, with tanh activation
        Lux.Dense(n_covar => n_covar * 4, tanh),
        Lux.Dense(n_covar * 4 => n_covar * 4, logistic),
        # dense layer without bias, mapping to n_θM outputs, with identity activation
        Lux.Dense(n_covar * 4 => n_θM, identity, use_bias=false),
    )
    ps, st = Lux.setup(Random.default_rng(), g_lux)
    ps_ca = CA.ComponentArray(ps) |> gpu
    st = st |> gpu
    g_luxs = StatefulLuxLayer{true}(g_lux, nothing, st)
    g_luxs(x_o_gpu[:, 1:n_batch], ps_ca)
    ax_g = CA.getaxes(ps_ca)
    g_luxs(x_o_gpu[:, 1:n_batch], CA.ComponentArray(ϕ.ϕg, ax_g))
    interpreters = (interpreters..., ϕg = ComponentArrayInterpreter(ps_ca))
    ϕg = CA.ComponentArray(ϕ.ϕg, ax_g)
    ϕgc = interpreters.ϕg(ϕ.ϕg)
    g_gpu = g_luxs
end
g_gpu = g_flux

#Zygote.gradient(ϕg -> sum(g_gpu(x_o_gpu[:, 1:n_batch],ϕg)), ϕgc)
# Zygote.gradient(ϕg -> sum(compute_g(g_gpu, x_o_gpu[:, 1:n_batch], ϕg, interpreters)), ϕ.ϕg)
# Zygote.gradient(ϕ -> sum(tmp_gen1(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ.ϕg)
# Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), CA.getdata(ϕ))
# Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
# Zygote.gradient(ϕ -> sum(tmp_gen3(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
# Zygote.gradient(ϕ -> sum(tmp_gen4(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)[1]), ϕ) |> cpu
# generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)
# Zygote.gradient(ϕ -> sum(generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)[1]), ϕ) |> cpu
# include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
# neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
# x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)
# Zygote.gradient(ϕ -> sum(neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
# x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)[1]), ϕ) |> cpu


fcost(ϕ) = neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
    x_o_gpu[:, 1:n_batch], transPMs_batch, map(get_concrete, interpreters);
    n_MC=8, logσ2y=logσ2y)
fcost(ϕ)
gr = Zygote.gradient(fcost, ϕ) |> cpu;
Zygote.gradient(fcost, CA.getdata(ϕ))


train_loader = MLUtils.DataLoader((x_o_gpu, y_o), batchsize = n_batch)

optf = Optimization.OptimizationFunction((ζg, data) -> begin
        x_o, y_o = data
        neg_elbo_transnorm_gf(
            rng, g_gpu, f, ζg, y_o, x_o, transPMs_batch,
            map(get_concrete, interpreters_g); n_MC=5, logσ2y)
    end,
    Optimization.AutoZygote())
# optprob = OptimizationProblem(optf, p0, train_loader) # stale leftover of the CPU fit; p0 does not match this parameter space
optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini) |> gpu, train_loader);
res = res_gpu = Optimization.solve(
    optprob, Optimisers.Adam(0.02), callback=callback_loss(50), maxiters=800);

ζ_VIc = interpreters_g.μP_ϕg_unc(res.u |> cpu)
ζMs_VI = g(x_o, ζ_VIc.ϕg)
ϕunc_VI = int_unc(ζ_VIc.unc)

hcat(θP_true, exp.(ζ_VIc.μP))
plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
#lineplot!(plt, 0.0, 1.1, identity)
hcat(ϕunc, ϕunc_VI) # need to compare to MC sample
# hard to estimate for the originally very small θ values, but otherwise good

# test whether the predictive posterior recovers the correct observation uncertainty
n_sample_pred = 200
intm_PMs_gen = ComponentArrayInterpreter(CA.ComponentVector(; θP,
    θMs=CA.ComponentMatrix(
        zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i=1:n_sample_pred))))
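# note: the column axis length must match zeros(n_θM, n_site); this relies on
# n_sample_pred == n_site here, otherwise CA.Axis(i=1:n_site) seems intended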

include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
ζs, _ = generate_ζ(rng, g, f, res.u |> cpu, x_o,
    (; interpreters..., PMs=intm_PMs_gen); n_MC=n_sample_pred)
# ζ = ζs[:,1]
θsc = stack(ζ -> CA.getdata(CA.ComponentVector(
    TransformVariables.transform(transPMs_all, ζ))), eachcol(ζs));
y_pred = stack(map(ζ -> first(predict_y(ζ, f, transPMs_all)), eachcol(ζs)));
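# Each column of ζs is one MC draw on the unconstrained scale: transform it back
# to positive (P, Ms) via transPMs_all and push it through f, so that y_pred
# stacks n_sample_pred posterior-predictive replicates along the third dimension.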

size(y_pred)
σ_o_post = mapslices(std, y_pred; dims=3);
#describe(σ_o_post)
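# compare the true observation σ and its MC estimate with the mean and RMS of
# the posterior-predictive spread; for calibrated uncertainty these should roughly agree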
vcat(σ_o, mean_σ_o_MC, mean(σ_o_post), sqrt(mean(abs2, σ_o_post)))
mean_y_pred = map(mean, eachslice(y_pred; dims=(1, 2)))
#describe(mean_y_pred - y_o)
histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_true)

# look at θP, θM1 of first site
intm = ComponentArrayInterpreter(int_θdoubleMM(1:length(int_θdoubleMM)), (n_sample_pred,))
ζs1c = intm(ζs[1:length(int_θdoubleMM), :])
vcat(θP_true, θM_true)
histogram(exp.(ζs1c[:r0, :]))
histogram(exp.(ζs1c[:K2, :]))
histogram(exp.(ζs1c[:r1, :]))
histogram(exp.(ζs1c[:K1, :]))
# all parameters estimated too high (true values not within the confidence bounds)
scatterplot(ζs1c[:r1, :], ζs1c[:K1, :]) # r1 and K1 strongly correlated (from θM)
scatterplot(ζs1c[:r0, :], ζs1c[:K2, :]) # r0 and K2 also correlated (from θP)
scatterplot(ζs1c[:r0, :], ζs1c[:K1, :]) # no correlation (modeled as independent)

# TODO compare distributions to MC sample
7 changes: 7 additions & 0 deletions src/HybridVariationalInference.jl
@@ -4,6 +4,10 @@ using ComponentArrays: ComponentArrays as CA
using Random
using StatsBase # fit ZScoreTransform
using Combinatorics # gen_hybridcase_synthetic/combinations
using GPUArraysCore
using LinearAlgebra
using CUDA
using ChainRulesCore

export ComponentArrayInterpreter, flatten1, get_concrete
include("ComponentArrayInterpreter.jl")
@@ -25,6 +29,9 @@ include("gencovar.jl")
export callback_loss
include("util_opt.jl")

# no exports: all functions are internal
include("cholesky.jl")

export DoubleMM
include("DoubleMM/DoubleMM.jl")
