Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Expand All @@ -34,12 +36,14 @@ BlockDiagonals = "0.1.42"
CUDA = "5.5.2"
ChainRulesCore = "1.25"
Combinatorics = "1.0.2"
CommonSolve = "0.2.4"
ComponentArrays = "0.15.19"
Flux = "v0.15.2, 0.16"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10.0"
Lux = "1.4.2"
MLUtils = "0.4.5"
Optimization = "3.19.3, 4"
Random = "1.10.0"
SimpleChains = "0.4"
StatsBase = "0.34.4"
Expand Down
225 changes: 168 additions & 57 deletions dev/doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,82 +5,195 @@ using StableRNGs
using Random
using Statistics
using ComponentArrays: ComponentArrays as CA

using Optimization
using OptimizationOptimisers # Adam
using UnicodePlots
using SimpleChains
import Flux
using Flux
using MLUtils
import Zygote

using CUDA
using OptimizationOptimisers
using Bijectors
using UnicodePlots

const prob = DoubleMM.DoubleMMCase()
scenario = (:default,)
rng = StableRNG(111)

par_templates = get_hybridproblem_par_templates(prob; scenario)

#n_covar = get_hybridproblem_n_covar(prob; scenario)
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
rng = StableRNG(114)
scenario = NTuple{0, Symbol}()
scenario = (:use_Flux,)

#------ setup synthetic data and training data loader
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
) = gen_hybridcase_synthetic(rng, prob; scenario);

n_covar = size(xM,1)
) = gen_hybridcase_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
xM_cpu = xM
if :use_Flux ∈ scenario
xM = CuArray(xM_cpu)
end
get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc);
batchsize = n_batch, partial = false)
σ_o = exp(first(y_unc)/2)

# assign the train_loader, otherwise it eatch time creates another version of synthetic data
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)

#------- pointwise hybrid model fit
solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
#solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
(; ϕ, resopt) = solve(prob0, solver; scenario,
rng, callback = callback_loss(100), maxiters = 1200);
# update the problem with optimized parameters
prob0o = HVI.update(prob0; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP)
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
scatterplot(θMs_true[1,:], θMs[1,:])
scatterplot(θMs_true[2,:], θMs[2,:])

# do a few steps without minibatching,
# by providing the data rather than the DataLoader
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
(; ϕ, resopt) = solve(prob0o, solver1; scenario, rng,
callback = callback_loss(20), maxiters = 600);
prob1o = HVI.update(prob0o; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP);
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
scatterplot(θMs_true[1,:], θMs[1,:])
scatterplot(θMs_true[2,:], θMs[2,:])
prob1o.θP
scatterplot(vec(y_true), vec(y_pred))

# still overestimating θMs

() -> begin # with more iterations?
prob2 = prob1o
(; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
callback = callback_loss(20), maxiters = 600);
prob2o = update(prob2; ϕg=ϕ.ϕg, θP=ϕ.θP)
y_pred_global, y_pred, θMs = gf(prob2o, xM, xP);
prob2o.θP
end


#----- fit g to θMs_true
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
#----------- fit g to true θMs
() -> begin
# and fit gf starting from true parameters
prob = prob0
g, ϕg0_cpu = get_hybridproblem_MLapplicator(prob; scenario);
ϕg0 = (:use_Flux ∈ scenario) ? CuArray(ϕg0_cpu) : ϕg0_cpu
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)

function loss_g(ϕg, x, g, transM; gpu_handler = HVI.default_GPU_DataHandler)
ζMs = g(x, ϕg) # predict the log of the parameters
ζMs_cpu = gpu_handler(ζMs)
θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each column
loss = sum(abs2, θMs .- θMs_true)
return loss, θMs
end
loss_g(ϕg0, xM, g, transM)

optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, ϕg0);
res = Optimization.solve(optprob, Adam(0.015), callback = callback_loss(100), maxiters = 2000);

ϕg_opt1 = res.u;
l1, θMs = loss_g(ϕg_opt1, xM, g, transM)
#scatterplot(θMs_true[1,:], θMs[1,:])
scatterplot(θMs_true[2,:], θMs[2,:]) # able to fit θMs[2,:]

prob3 = HVI.update(prob0, ϕg = Array(ϕg_opt1), θP = θP_true)
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
(; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
callback = callback_loss(50), maxiters = 600);
prob3o = HVI.update(prob3; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP)
y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario);
scatterplot(θMs_true[2,:], θMs[2,:])
prob3o.θP
scatterplot(vec(y_true), vec(y_pred))
scatterplot(vec(y_true), vec(y_o))
scatterplot(vec(y_pred), vec(y_o))

function loss_g(ϕg, x, g, transM)
ζMs = g(x, ϕg) # predict the log of the parameters
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
loss = sum(abs2, θMs .- θMs_true)
return loss, θMs
() -> begin # optimized loss is indeed lower than with true parameters
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc)[1]
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc)[1]
#
loss_gf(vcat(prob2o.ϕg, prob2o.θP), xM, xP, y_o, y_unc)[1]
end
end
loss_g(ϕg0, xM, g, transM)

#----------- Hybrid Variational inference: HVI

optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, ϕg0);
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
using MLUtils
import Zygote

using CUDA
using Bijectors

ϕg_opt1 = res.u;
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
scatterplot(vec(θMs_true), vec(θMs_pred))
solver = HybridPosteriorSolver(; alg = Adam(0.01), n_batch = 60, n_MC = 3)
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
(; ϕ, θP, resopt) = solve(prob0o, solver; scenario,
rng, callback = callback_loss(100), maxiters = 800);
# update the problem with optimized parameters
prob1o = HVI.update(prob0o; ϕg=cpu_ca(ϕ).ϕg, θP=θP)
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
scatterplot(θMs_true[1,:], θMs[1,:])
scatterplot(θMs_true[2,:], θMs[2,:])
hcat(θP_true, θP) # all parameters overestimated

f = get_hybridproblem_PBmodel(prob; scenario)
py = get_hybridproblem_neg_logden_obs(prob; scenario)

#----------- fit g and θP to y_o
() -> begin
# end2end inversion
#n_covar = get_hybridproblem_n_covar(prob; scenario)
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)

int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
ϕg = 1:length(ϕg0), θP = par_templates.θP))
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
n_covar = size(xM, 1)

# Pass the site-data for the batches as separate vectors wrapped in a tuple
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
#----- fit g to θMs_true
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)

loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
l1 = loss_gf(p0, train_loader.data...)[1]
function loss_g(ϕg, x, g, transM)
ζMs = g(x, ϕg) # predict the log of the parameters
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
loss = sum(abs2, θMs .- θMs_true)
return loss, θMs
end
loss_g(ϕg0, xM, g, transM)

optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
Optimization.AutoZygote())
optprob = OptimizationProblem(optf, p0, train_loader)
optprob = Optimization.OptimizationProblem(optf, ϕg0);
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);

res = Optimization.solve(
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
ϕg_opt1 = res.u;
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
scatterplot(vec(θMs_true), vec(θMs_pred))

l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
scatterplot(vec(θMs_true), vec(θMs))
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
scatterplot(vec(y_pred), vec(y_o))
hcat(par_templates.θP, int_ϕθP(res.u).θP)
f = get_hybridproblem_PBmodel(prob; scenario)
py = get_hybridproblem_neg_logden_obs(prob; scenario)

#----------- fit g and θP to y_o
() -> begin
# end2end inversion

int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
ϕg = 1:length(ϕg0), θP = par_templates.θP))
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true

# Pass the site-data for the batches as separate vectors wrapped in a tuple
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)

loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
l1 = loss_gf(p0, train_loader.data...)[1]

optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
Optimization.AutoZygote())
optprob = OptimizationProblem(optf, p0, train_loader)

res = Optimization.solve(
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)

l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
scatterplot(vec(θMs_true), vec(θMs))
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
scatterplot(vec(y_pred), vec(y_o))
hcat(par_templates.θP, int_ϕθP(res.u).θP)
end
end

#---------- HVI
Expand All @@ -92,8 +205,6 @@ FT = get_hybridproblem_float_type(prob; scenario)
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
ϕ_true = ϕ



() -> begin
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
Expand Down Expand Up @@ -245,7 +356,7 @@ y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
size(y_pred) # n_obs x n_site, n_sample_pred

σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
σ_o = exp.(y_unc[:,1] / 2)
σ_o = exp.(y_unc[:, 1] / 2)

#describe(σ_o_post)
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
Expand Down
5 changes: 5 additions & 0 deletions ext/HybridVariationalInferenceFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ function HVI.construct_3layer_MLApplicator(
construct_ChainsApplicator(rng, g_chain, float_type)
end

function HVI.cpu_ca(ca::CA.ComponentArray)
CA.ComponentArray(cpu(CA.getdata(ca)), CA.getaxes(ca))
end




end # module
Loading
Loading