Merged

Dev #14

3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
@@ -49,7 +49,8 @@ jobs:
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
#fail_ci_if_error: false
fail_ci_if_error: true
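# note: a failing Codecov upload now fails the workflow instead of passing silently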
docs:
name: Documentation
runs-on: ubuntu-latest
121 changes: 48 additions & 73 deletions dev/doubleMM.jl
@@ -7,7 +7,7 @@ using Statistics
using ComponentArrays: ComponentArrays as CA

using SimpleChains
import Flux # to allow for FluxMLEngine and cpu()
import Flux
using MLUtils
import Zygote

@@ -17,41 +17,43 @@ using Bijectors
using UnicodePlots

const case = DoubleMM.DoubleMMCase()
const MLengine = Val(nameof(SimpleChains))
const FluxMLengine = Val(nameof(Flux))
scenario = (:default,)
rng = StableRNG(111)

par_templates = get_hybridcase_par_templates(case; scenario)

(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
#n_covar = get_hybridcase_n_covar(case; scenario)
#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)

(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
) = gen_hybridcase_synthetic(rng, case; scenario);

n_covar = size(xM,1)

(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
) = gen_hybridcase_synthetic(case, rng; scenario);

#----- fit g to θMs_true
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
(; transP, transM) = get_hybridcase_transforms(case; scenario)

function loss_g(ϕg, x, g)
function loss_g(ϕg, x, g, transM)
ζMs = g(x, ϕg) # predict the log of the parameters
θMs = exp.(ζMs)
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
loss = sum(abs2, θMs .- θMs_true)
return loss, θMs
end
loss_g(ϕg0, xM, g)
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
loss_g(ϕg0, xM, g, transM)
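# quick differentiability check of the extended loss before optimizing
# (a sketch, not part of this changeset; mirrors the gradient check removed above)
Zygote.gradient(ϕ -> loss_g(ϕ, xM, g, transM)[1], ϕg0);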

optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, ϕg0);
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);

ϕg_opt1 = res.u;
loss_g(ϕg_opt1, xM, g)
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
scatterplot(vec(θMs_true), vec(θMs_pred))
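# a correlation check could still guard the fit
# (sketch; mirrors the removed @test above and assumes Test is loaded)
@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9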

f = get_hybridcase_PBmodel(case; scenario)
py = get_hybridcase_neg_logden_obs(case; scenario)
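# py is the negative log-density of the observations; it enters the ELBO via
# neg_elbo_transnorm_gf below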

#----------- fit g and θP to y_o
() -> begin
@@ -62,7 +64,7 @@ f = get_hybridcase_PBmodel(case; scenario)
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true

# Pass the site-data for the batches as separate vectors wrapped in a tuple
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)

loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
l1 = loss_gf(p0, train_loader.data...)[1]
@@ -82,15 +84,16 @@ f = get_hybridcase_PBmodel(case; scenario)
end

#---------- HVI
logσ2y = 2 .* log.(σ_o)
n_MC = 3
transP = elementwise(exp)
transM = Stacked(elementwise(identity), elementwise(exp))
(; transP, transM) = get_hybridcase_transforms(case; scenario)
FT = get_hybridcase_float_type(case; scenario)

(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
ϕ_true = ϕ



() -> begin
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -149,49 +152,22 @@ transM = Stacked(elementwise(identity), elementwise(exp))
ϕ_true = inverse_ca(trans_gu, ϕt_true)
end

ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch
ϕ_ini0 = ζ = reduce(
vcat, (
ϕ_true[[:μP]] .* FT(0.001), CA.ComponentVector(ϕg = ϕg0), ϕ_true[[:unc]])) # scratch
#
# true values
ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance
ϕ_ini = ζ = reduce(
vcat, (
ϕ_true[[:μP]] .- FT(0.1), ϕ_true[[:ϕg]] .* FT(1.1), ϕ_true[[:unc]])) # slight disturbance
# hardcoded from HMC inversion
ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
mean_σ_o_MC = 0.006042

# test cost function and gradient
() -> begin
neg_elbo_transnorm_gf(rng, g, f, ϕ_true, y_o[:, 1:n_batch], xM[:, 1:n_batch],
transPMs_batch, map(get_concrete, interpreters);
n_MC = 8, logσ2y)
Zygote.gradient(
ϕ -> neg_elbo_transnorm_gf(
rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch],
transPMs_batch, interpreters; n_MC = 8, logσ2y),
CA.getdata(ϕ_true))
end

# optimize using SimpleChains
() -> begin
train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch)

optf = Optimization.OptimizationFunction(
(ϕ, data) -> begin
xM, y_o = data
neg_elbo_transnorm_gf(
rng, g, f, ϕ, y_o, xM, transPMs_batch,
map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
end,
Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader)
res = Optimization.solve(
optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 800)
#optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
#res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
end

ϕ = ϕ_ini |> Flux.gpu;
ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
xM_gpu = xM |> Flux.gpu;
g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
scenario_flux = (scenario..., :use_Flux)
g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux);

# optimize using Lux
() -> begin
Expand All @@ -216,27 +192,25 @@ g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario
g_flux = g_luxs
end

function fcost(ϕ, xM, y_o)
neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o,
xM, transPMs_batch, map(get_concrete, interpreters);
n_MC = 8, logσ2y = logσ2y)
function fcost(ϕ, xM, y_o, y_unc)
neg_elbo_transnorm_gf(rng, CA.getdata(ϕ), g_flux, transPMs_batch, f, py,
xM, xP, y_o, y_unc, map(get_concrete, interpreters);
n_MC = 8)
end
fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch])
fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch])
#Zygote.gradient(fcost, ϕ) |> cpu;
gr = Zygote.gradient(fcost,
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]),
CA.getdata(y_o[:, 1:n_batch]), CA.getdata(y_unc[:, 1:n_batch]));
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
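# the ComponentArray view exposes gradient blocks by name (sketch, assuming the
# axes of ϕ_ini as constructed above):
gr_c.μP # w.r.t. global physical parameters
gr_c.ϕg # w.r.t. ML network weights
gr_c.unc # w.r.t. uncertainty parameters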

train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux))

optf = Optimization.OptimizationFunction(
(ϕ, data) -> begin
xM, y_o = data
fcost(ϕ, xM, y_o)
# neg_elbo_transnorm_gf(
# rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
# map(get_concrete, interpreters); n_MC = 5, logσ2y)
xM, xP, y_o, y_unc = data
fcost(ϕ, xM, y_o, y_unc)
end,
Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(
@@ -256,7 +230,7 @@ end
ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
ϕunc_VI = interpreters.unc(ζ_VIc.unc)

hcat(θP_true, exp.(ζ_VIc.μP))
hcat(log.(θP_true), ϕ_ini.μP, ζ_VIc.μP)
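# the comparison above is in unconstrained (log) space: true θP vs. initial and fitted μP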
plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
#lineplot!(plt, 0.0, 1.1, identity)
#
@@ -266,11 +240,12 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
# test predicting correct obs-uncertainty of predictive posterior
n_sample_pred = 200

y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
get_transPMs, get_ca_int_PMs, n_sample_pred);
size(y_pred) # n_obs x n_site, n_sample_pred

σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
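# recover per-observation σ from the log-variance stored in y_unc: σ = exp(logσ²/2)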
σ_o = exp.(y_unc[:,1] / 2)

#describe(σ_o_post)
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
@@ -282,7 +257,7 @@ histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_true)

# look at θP, θM1 of first site
intm_PMs_gen = get_ca_int_PMs(n_site)
ζs, _σ = HVI.generate_ζ(rng, g_flux, f, res.u, xM_gpu,
ζs, _σ = HVI.generate_ζ(rng, g_flux, res.u, xM_gpu,
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred);
ζs = ζs |> Flux.cpu;
θPM = vcat(θP_true, θMs_true[:, 1])
29 changes: 17 additions & 12 deletions ext/HybridVariationalInferenceFluxExt.jl
@@ -3,12 +3,14 @@
using HybridVariationalInference, Flux
using HybridVariationalInference: HybridVariationalInference as HVI
using ComponentArrays: ComponentArrays as CA
using Random

struct FluxApplicator{RT} <: AbstractModelApplicator
rebuild::RT
end

function HVI.construct_FluxApplicator(m::Chain)
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
# TODO: care for rng and float_type
ϕ, rebuild = destructure(m)
FluxApplicator(rebuild), ϕ
end
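# One way to address the TODO above (a sketch, not part of this changeset):
# float_type could be honored by converting parameters before destructuring,
# e.g. m = float_type === Float32 ? Flux.f32(m) : Flux.f64(m);
# the rng would need to enter at layer construction time,
# e.g. Dense(...; init = Flux.glorot_uniform(rng)), because destructure
# only flattens already-initialized weights.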
@@ -26,18 +28,21 @@
HVI.set_default_GPUHandler(FluxGPUDataHandler())
end

function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
args...; kwargs...)
# constructor with Flux.Chain
g, ϕg = construct_FluxApplicator(g_chain)
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
end
# function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
# args...; kwargs...)
# # constructor with Flux.Chain
# g, ϕg = construct_FluxApplicator(g_chain)
# HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
# end

function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
function HVI.construct_3layer_MLApplicator(
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux};
scenario::NTuple = ())
(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
FloatType = get_hybridcase_FloatType(case; scenario)
n_out = n_θM
(;θM) = get_hybridcase_par_templates(case; scenario)
n_out = length(θM)
n_covar = get_hybridcase_n_covar(case; scenario)
#(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
float_type = get_hybridcase_float_type(case; scenario)
is_using_dropout = :use_dropout ∈ scenario
is_using_dropout && error("dropout scenario not supported with Flux yet.")
g_chain = Flux.Chain(
@@ -47,7 +52,7 @@
# dense layer without bias that maps to n outputs and `identity` activation
Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
)
construct_FluxApplicator(g_chain)
construct_ChainsApplicator(rng, g_chain, float_type)
end
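# usage sketch (hypothetical call; assumes an rng and case as elsewhere):
#   g, ϕ0 = HVI.construct_3layer_MLApplicator(rng, case, Val(:Flux); scenario)
#   ζ = g(xM, ϕ0) # applicators are callable on covariates and parameters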


Expand Down
17 changes: 9 additions & 8 deletions ext/HybridVariationalInferenceLuxExt.jl
@@ -10,8 +10,8 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator
int_ϕ::IT
end

function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
ps, st = Lux.setup(Random.default_rng(), m)
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type=Float32; device = gpu_device())
ps, st = Lux.setup(rng, m)
ps_ca = float_type.(CA.ComponentArray(ps))
st = st |> device
stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
@@ -25,11 +25,12 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ)
app.stateful_layer(x, ϕc)
end
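# usage sketch (hypothetical; assumes a Lux.Chain `m` and cpu_device() as fallback):
#   g, ϕ0 = HVI.construct_ChainsApplicator(rng, m, Float32; device = cpu_device())
#   y = g(x, ϕ0)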

function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
args...; device = gpu_device(), kwargs...)
# constructor with SimpleChain
g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
end
# function HVI.HybridProblem(rng::AbstractRNG,
# θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
# args...; device = gpu_device(), kwargs...)
# # constructor with Lux.Chain
# g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM); device)
# HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
# end

end # module
26 changes: 11 additions & 15 deletions ext/HybridVariationalInferenceSimpleChainsExt.jl
@@ -4,32 +4,28 @@ using HybridVariationalInference, SimpleChains
using HybridVariationalInference: HybridVariationalInference as HVI
using StatsFuns: logistic
using ComponentArrays: ComponentArrays as CA
using Random



struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
m::MT
end

function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
ϕ = SimpleChains.init_params(m, FloatType);
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::SimpleChain, FloatType=Float32)
ϕ = SimpleChains.init_params(m, FloatType; rng);
SimpleChainsApplicator(m), ϕ
end

HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
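# usage sketch (hypothetical; assumes a SimpleChain `m` such as the one below):
#   g, ϕ0 = HVI.construct_ChainsApplicator(rng, m, Float32)
#   ζ = g(x, ϕ0) # equivalent to HVI.apply_model(g, x, ϕ0)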

function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
args...; kwargs...)
# constructor with SimpleChain
g, ϕg = construct_SimpleChainsApplicator(g_chain)
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
end

function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
scenario::NTuple=())
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
FloatType = get_hybridcase_FloatType(case; scenario)
n_out = n_θM
function HVI.construct_3layer_MLApplicator(
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains};
scenario::NTuple = ())
n_covar = get_hybridcase_n_covar(case; scenario)
FloatType = get_hybridcase_float_type(case; scenario)
(;θM) = get_hybridcase_par_templates(case; scenario)
n_out = length(θM)
is_using_dropout = :use_dropout ∈ scenario
g_chain = if is_using_dropout
SimpleChain(
@@ -52,7 +48,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
TurboDense{false}(identity, n_out)
)
end
construct_SimpleChainsApplicator(g_chain, FloatType)
construct_ChainsApplicator(rng, g_chain, FloatType)
end

end # module