Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Expand All @@ -38,6 +39,7 @@ Flux = "v0.15.2, 0.16"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10.0"
Lux = "1.4.2"
MLUtils = "0.4.5"
Random = "1.10.0"
SimpleChains = "0.4"
StatsBase = "0.34.4"
Expand Down
18 changes: 11 additions & 7 deletions dev/doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ using MLUtils
import Zygote

using CUDA
using TransformVariables
using OptimizationOptimisers
using Bijectors
using UnicodePlots

const case = DoubleMM.DoubleMMCase()
Expand All @@ -24,13 +24,13 @@ rng = StableRNG(111)

par_templates = get_hybridcase_par_templates(case; scenario)

(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)

(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
) = gen_hybridcase_synthetic(case, rng; scenario);

#----- fit g to θMs_true
g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);

function loss_g(ϕg, x, g)
ζMs = g(x, ϕg) # predict the log of the parameters
Expand All @@ -51,7 +51,7 @@ loss_g(ϕg_opt1, xM, g)
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9

f = gen_hybridcase_PBmodel(case; scenario)
f = get_hybridcase_PBmodel(case; scenario)

#----------- fit g and θP to y_o
() -> begin
Expand Down Expand Up @@ -84,6 +84,9 @@ end
#---------- HVI
logσ2y = 2 .* log.(σ_o)
n_MC = 3
transP = elementwise(exp)
transM = Stacked(elementwise(identity), elementwise(exp))

(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
ϕ_true = ϕ
Expand Down Expand Up @@ -188,7 +191,7 @@ end

ϕ = ϕ_ini |> Flux.gpu;
xM_gpu = xM |> Flux.gpu;
g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario);
g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);

# otpimize using LUX
() -> begin
Expand Down Expand Up @@ -224,7 +227,8 @@ gr = Zygote.gradient(fcost,
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)

train_loader = MLUtils.DataLoader((xM_gpu, y_o), batchsize = n_batch)
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))

optf = Optimization.OptimizationFunction(
(ϕ, data) -> begin
Expand Down
19 changes: 14 additions & 5 deletions ext/HybridVariationalInferenceFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

using HybridVariationalInference, Flux
using HybridVariationalInference: HybridVariationalInference as HVI
using ComponentArrays: ComponentArrays as CA

struct FluxApplicator{RT} <: AbstractModelApplicator
rebuild::RT
end

function HVI.construct_FluxApplicator(m::Chain)
_, rebuild = destructure(m)
FluxApplicator(rebuild)
ϕ, rebuild = destructure(m)
FluxApplicator(rebuild), ϕ
end

function HVI.apply_model(app::FluxApplicator, x, ϕ)
Expand All @@ -25,7 +26,14 @@
HVI.set_default_GPUHandler(FluxGPUDataHandler())
end

function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,

Check warning on line 29 in ext/HybridVariationalInferenceFluxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceFluxExt.jl#L29

Added line #L29 was not covered by tests
args...; kwargs...)
# constructor with Flux.Chain
g, ϕg = construct_FluxApplicator(g_chain)
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)

Check warning on line 33 in ext/HybridVariationalInferenceFluxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceFluxExt.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
end

function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
scenario::NTuple = ())
(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
FloatType = get_hybridcase_FloatType(case; scenario)
Expand All @@ -39,8 +47,9 @@
# dense layer without bias that maps to n outputs and `identity` activation
Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
)
ϕ, _ = destructure(g_chain)
construct_FluxApplicator(g_chain), ϕ
construct_FluxApplicator(g_chain)
end



end # module
13 changes: 10 additions & 3 deletions ext/HybridVariationalInferenceLuxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,26 @@
int_ϕ::IT
end

function HVI.construct_LuxApplicator(m::Chain; device = gpu_device())
function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
ps, st = Lux.setup(Random.default_rng(), m)
ps_ca = CA.ComponentArray(ps)
ps_ca = float_type.(CA.ComponentArray(ps))
st = st |> device
stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
#stateful_layer(x_o_gpu[:, 1:n_site_batch], ps_ca)
int_ϕ = get_concrete(ComponentArrayInterpreter(ps_ca))
LuxApplicator(stateful_layer, int_ϕ)
LuxApplicator(stateful_layer, int_ϕ), ps_ca
end

function HVI.apply_model(app::LuxApplicator, x, ϕ)
ϕc = app.int_ϕ(ϕ)
app.stateful_layer(x, ϕc)
end

function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,

Check warning on line 28 in ext/HybridVariationalInferenceLuxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceLuxExt.jl#L28

Added line #L28 was not covered by tests
args...; device = gpu_device(), kwargs...)
# constructor with SimpleChain
g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)

Check warning on line 32 in ext/HybridVariationalInferenceLuxExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/HybridVariationalInferenceLuxExt.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
end

end # module
20 changes: 16 additions & 4 deletions ext/HybridVariationalInferenceSimpleChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,29 @@ module HybridVariationalInferenceSimpleChainsExt
using HybridVariationalInference, SimpleChains
using HybridVariationalInference: HybridVariationalInference as HVI
using StatsFuns: logistic
using ComponentArrays: ComponentArrays as CA



struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
m::MT
end

HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
ϕ = SimpleChains.init_params(m, FloatType);
SimpleChainsApplicator(m), ϕ
end

HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)

function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
args...; kwargs...)
# constructor with SimpleChain
g, ϕg = construct_SimpleChainsApplicator(g_chain)
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
end

function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
scenario::NTuple=())
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
FloatType = get_hybridcase_FloatType(case; scenario)
Expand All @@ -39,8 +52,7 @@ function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
TurboDense{false}(identity, n_out)
)
end
ϕ = SimpleChains.init_params(g_chain, FloatType);
SimpleChainsApplicator(g_chain), ϕ
construct_SimpleChainsApplicator(g_chain, FloatType)
end

end # module
4 changes: 3 additions & 1 deletion src/DoubleMM/DoubleMM.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
module DoubleMM

using HybridVariationalInference
using HybridVariationalInference: HybridVariationalInference as HVI
using ComponentArrays: ComponentArrays as CA
using Random
using Combinatorics
using StatsFuns: logistic
using Bijectors


export f_doubleMM, xP_S1, xP_S2
include("f_doubleMM.jl")

export f_doubleMM, S1, S2

end
65 changes: 40 additions & 25 deletions src/DoubleMM/f_doubleMM.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,79 @@
struct DoubleMMCase <: AbstractHybridCase end

const S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
const S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]

θP = CA.ComponentVector(r0 = 0.3, K2 = 2.0)
θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)

transP = elementwise(exp)
transM = Stacked(elementwise(identity), elementwise(exp))


const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))

function f_doubleMM(θ::AbstractVector)
function f_doubleMM(θ::AbstractVector, x)
# extract parameters not depending on order, i.e whether they are in θP or θM
θc = int_θdoubleMM(θ)
r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)]
y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
return (y)
end

function HybridVariationalInference.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
(; θP, θM)
end

function HybridVariationalInference.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
function HVI.get_hybridcase_transforms(::AbstractHybridCase; scenario::NTuple = ())
(; transP, transM)
end

function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
n_covar_pc = 2
n_covar = n_covar_pc + 3 # linear dependent
n_site = 10^n_covar_pc
#n_site = 10^n_covar_pc
n_batch = 10
n_θM = length(θM)
n_θP = length(θP)
(; n_covar, n_site, n_batch, n_θM, n_θP)
#(; n_covar, n_site, n_batch, n_θM, n_θP)
(; n_covar, n_batch, n_θM, n_θP)
end

function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
pred_sites = applyf(fsite, θMs, θP, x)
pred_sites = applyf(f_doubleMM, θMs, θP, x)
pred_global = eltype(pred_sites)[]
return pred_global, pred_sites
end
end

function HybridVariationalInference.get_hybridcase_FloatType(::DoubleMMCase; scenario)
return Float32
end
# function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario)
# return Float32
# end

function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]

function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
scenario = ())
n_covar_pc = 2
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
n_site = 200
(; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
FloatType = get_hybridcase_FloatType(case; scenario)
xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
rhodec = 8, is_using_dropout = false)
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
# normalize to be distributed around the prescribed true values
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, 0.1))
f = gen_hybridcase_PBmodel(case; scenario)
xP = fill((), n_site)
y_global_true, y_true = f(θP, θMs_true, zip())
σ_o = 0.01
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
f = get_hybridcase_PBmodel(case; scenario)
xP = fill((;S1=xP_S1, S2=xP_S2), n_site)
y_global_true, y_true = f(θP, θMs_true, xP)
σ_o = FloatType(0.01)
#σ_o = 0.002
y_global_o = y_global_true .+ randn(rng, size(y_global_true)) .* σ_o
y_o = y_true .+ randn(rng, size(y_true)) .* σ_o
y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o
y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o
(;
xM,
n_site,
θP_true = θP,
θMs_true,
xP,
Expand All @@ -72,3 +84,6 @@ function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase,
σ_o = fill(σ_o, size(y_true,1)),
)
end



55 changes: 55 additions & 0 deletions src/HybridProblem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
struct HybridProblem <: AbstractHybridCase
θP
θM
transP
transM
n_covar
n_batch
f
g
ϕg
train_loader
# inner constructor to constrain the types
function HybridProblem(
θP::CA.ComponentVector, θM::CA.ComponentVector,
g::AbstractModelApplicator, ϕg,
f::Function,
transM::Union{Function, Bijectors.Transform},
transP::Union{Function, Bijectors.Transform},
n_covar::Integer, n_batch::Integer,
train_loader::DataLoader)
new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg, train_loader)
end
end

function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ())
(; θP = prob.θP, θM = prob.θM)
end

function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ())
n_θM = length(prob.θM)
n_θP = length(prob.θP)
(; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP)

Check warning on line 32 in src/HybridProblem.jl

View check run for this annotation

Codecov / codecov/patch

src/HybridProblem.jl#L29-L32

Added lines #L29 - L32 were not covered by tests
end

function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ())
prob.f
end

function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::NTuple = ());
prob.g, prob.ϕg
end

function get_hybridcase_train_dataloader(
prob::HybridProblem, rng::AbstractRNG = Random.default_rng();
scenario = ())
return(prob.train_loader)
end


# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
# eltype(prob.θM)
# end



Loading
Loading