Skip to content

Commit b60f297

Browse files
authored
Merge pull request #13 from EarthyScience/dev
implement HybridProblem
2 parents 2b04312 + 3086e10 commit b60f297

23 files changed

+393
-118
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1212
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1313
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1718
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -38,6 +39,7 @@ Flux = "v0.15.2, 0.16"
3839
GPUArraysCore = "0.1, 0.2"
3940
LinearAlgebra = "1.10.0"
4041
Lux = "1.4.2"
42+
MLUtils = "0.4.5"
4143
Random = "1.10.0"
4244
SimpleChains = "0.4"
4345
StatsBase = "0.34.4"

dev/doubleMM.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ using MLUtils
1212
import Zygote
1313

1414
using CUDA
15-
using TransformVariables
1615
using OptimizationOptimisers
16+
using Bijectors
1717
using UnicodePlots
1818

1919
const case = DoubleMM.DoubleMMCase()
@@ -24,13 +24,13 @@ rng = StableRNG(111)
2424

2525
par_templates = get_hybridcase_par_templates(case; scenario)
2626

27-
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
27+
(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
2828

29-
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
29+
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
3030
) = gen_hybridcase_synthetic(case, rng; scenario);
3131

3232
#----- fit g to θMs_true
33-
g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
33+
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
3434

3535
function loss_g(ϕg, x, g)
3636
ζMs = g(x, ϕg) # predict the log of the parameters
@@ -51,7 +51,7 @@ loss_g(ϕg_opt1, xM, g)
5151
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
5252
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
5353

54-
f = gen_hybridcase_PBmodel(case; scenario)
54+
f = get_hybridcase_PBmodel(case; scenario)
5555

5656
#----------- fit g and θP to y_o
5757
() -> begin
@@ -84,6 +84,9 @@ end
8484
#---------- HVI
8585
logσ2y = 2 .* log.(σ_o)
8686
n_MC = 3
87+
transP = elementwise(exp)
88+
transM = Stacked(elementwise(identity), elementwise(exp))
89+
8790
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
8891
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
8992
ϕ_true = ϕ
@@ -188,7 +191,7 @@ end
188191

189192
ϕ = ϕ_ini |> Flux.gpu;
190193
xM_gpu = xM |> Flux.gpu;
191-
g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario);
194+
g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
192195

193196
# optimize using Lux
194197
() -> begin
@@ -224,7 +227,8 @@ gr = Zygote.gradient(fcost,
224227
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
225228
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
226229

227-
train_loader = MLUtils.DataLoader((xM_gpu, y_o), batchsize = n_batch)
230+
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
231+
train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
228232

229233
optf = Optimization.OptimizationFunction(
230234
(ϕ, data) -> begin

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@ module HybridVariationalInferenceFluxExt
22

33
using HybridVariationalInference, Flux
44
using HybridVariationalInference: HybridVariationalInference as HVI
5+
using ComponentArrays: ComponentArrays as CA
56

67
struct FluxApplicator{RT} <: AbstractModelApplicator
78
rebuild::RT
89
end
910

1011
function HVI.construct_FluxApplicator(m::Chain)
11-
_, rebuild = destructure(m)
12-
FluxApplicator(rebuild)
12+
ϕ, rebuild = destructure(m)
13+
FluxApplicator(rebuild), ϕ
1314
end
1415

1516
function HVI.apply_model(app::FluxApplicator, x, ϕ)
@@ -25,7 +26,14 @@ function __init__()
2526
HVI.set_default_GPUHandler(FluxGPUDataHandler())
2627
end
2728

28-
function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
29+
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
30+
args...; kwargs...)
31+
# constructor with Flux.Chain
32+
g, ϕg = construct_FluxApplicator(g_chain)
33+
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
34+
end
35+
36+
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
2937
scenario::NTuple = ())
3038
(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
3139
FloatType = get_hybridcase_FloatType(case; scenario)
@@ -39,8 +47,9 @@ function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
3947
# dense layer without bias that maps to n outputs and `identity` activation
4048
Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
4149
)
42-
ϕ, _ = destructure(g_chain)
43-
construct_FluxApplicator(g_chain), ϕ
50+
construct_FluxApplicator(g_chain)
4451
end
4552

53+
54+
4655
end # module

ext/HybridVariationalInferenceLuxExt.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,26 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator
1010
int_ϕ::IT
1111
end
1212

13-
function HVI.construct_LuxApplicator(m::Chain; device = gpu_device())
13+
function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
1414
ps, st = Lux.setup(Random.default_rng(), m)
15-
ps_ca = CA.ComponentArray(ps)
15+
ps_ca = float_type.(CA.ComponentArray(ps))
1616
st = st |> device
1717
stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
1818
#stateful_layer(x_o_gpu[:, 1:n_site_batch], ps_ca)
1919
int_ϕ = get_concrete(ComponentArrayInterpreter(ps_ca))
20-
LuxApplicator(stateful_layer, int_ϕ)
20+
LuxApplicator(stateful_layer, int_ϕ), ps_ca
2121
end
2222

2323
function HVI.apply_model(app::LuxApplicator, x, ϕ)
2424
ϕc = app.int_ϕ(ϕ)
2525
app.stateful_layer(x, ϕc)
2626
end
2727

28+
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
29+
args...; device = gpu_device(), kwargs...)
30+
# constructor with Lux.Chain
31+
g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
32+
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
33+
end
34+
2835
end # module

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,29 @@ module HybridVariationalInferenceSimpleChainsExt
33
using HybridVariationalInference, SimpleChains
44
using HybridVariationalInference: HybridVariationalInference as HVI
55
using StatsFuns: logistic
6+
using ComponentArrays: ComponentArrays as CA
7+
8+
69

710
struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
811
m::MT
912
end
1013

11-
HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
14+
function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
15+
ϕ = SimpleChains.init_params(m, FloatType);
16+
SimpleChainsApplicator(m), ϕ
17+
end
1218

1319
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
1420

15-
function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
21+
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
22+
args...; kwargs...)
23+
# constructor with SimpleChain
24+
g, ϕg = construct_SimpleChainsApplicator(g_chain)
25+
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
26+
end
27+
28+
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
1629
scenario::NTuple=())
1730
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
1831
FloatType = get_hybridcase_FloatType(case; scenario)
@@ -39,8 +52,7 @@ function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
3952
TurboDense{false}(identity, n_out)
4053
)
4154
end
42-
ϕ = SimpleChains.init_params(g_chain, FloatType);
43-
SimpleChainsApplicator(g_chain), ϕ
55+
construct_SimpleChainsApplicator(g_chain, FloatType)
4456
end
4557

4658
end # module

src/DoubleMM/DoubleMM.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
module DoubleMM
22

33
using HybridVariationalInference
4+
using HybridVariationalInference: HybridVariationalInference as HVI
45
using ComponentArrays: ComponentArrays as CA
56
using Random
67
using Combinatorics
78
using StatsFuns: logistic
9+
using Bijectors
810

911

12+
export f_doubleMM, xP_S1, xP_S2
1013
include("f_doubleMM.jl")
1114

12-
export f_doubleMM, S1, S2
1315

1416
end

src/DoubleMM/f_doubleMM.jl

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,79 @@
11
struct DoubleMMCase <: AbstractHybridCase end
22

3-
const S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
4-
const S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
53

6-
θP = CA.ComponentVector(r0 = 0.3, K2 = 2.0)
7-
θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
4+
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
5+
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
6+
7+
transP = elementwise(exp)
8+
transM = Stacked(elementwise(identity), elementwise(exp))
9+
810

911
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
1012

11-
function f_doubleMM(θ::AbstractVector)
13+
function f_doubleMM(θ::AbstractVector, x)
1214
# extract parameters not depending on order, i.e whether they are in θP or θM
1315
θc = int_θdoubleMM(θ)
1416
r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)]
15-
y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
17+
y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
1618
return (y)
1719
end
1820

19-
function HybridVariationalInference.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
21+
function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
2022
(; θP, θM)
2123
end
2224

23-
function HybridVariationalInference.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
25+
function HVI.get_hybridcase_transforms(::AbstractHybridCase; scenario::NTuple = ())
26+
(; transP, transM)
27+
end
28+
29+
function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
2430
n_covar_pc = 2
2531
n_covar = n_covar_pc + 3 # linear dependent
26-
n_site = 10^n_covar_pc
32+
#n_site = 10^n_covar_pc
2733
n_batch = 10
2834
n_θM = length(θM)
2935
n_θP = length(θP)
30-
(; n_covar, n_site, n_batch, n_θM, n_θP)
36+
#(; n_covar, n_site, n_batch, n_θM, n_θP)
37+
(; n_covar, n_batch, n_θM, n_θP)
3138
end
3239

33-
function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
34-
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
40+
function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
41+
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
3542
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
36-
pred_sites = applyf(fsite, θMs, θP, x)
43+
pred_sites = applyf(f_doubleMM, θMs, θP, x)
3744
pred_global = eltype(pred_sites)[]
3845
return pred_global, pred_sites
3946
end
4047
end
4148

42-
function HybridVariationalInference.get_hybridcase_FloatType(::DoubleMMCase; scenario)
43-
return Float32
44-
end
49+
# function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario)
50+
# return Float32
51+
# end
4552

46-
function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
53+
const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
54+
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
55+
56+
function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
4757
scenario = ())
4858
n_covar_pc = 2
49-
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
59+
n_site = 200
60+
(; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
5061
FloatType = get_hybridcase_FloatType(case; scenario)
5162
xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
5263
rhodec = 8, is_using_dropout = false)
5364
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
5465
# normalize to be distributed around the prescribed true values
55-
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, 0.1))
56-
f = gen_hybridcase_PBmodel(case; scenario)
57-
xP = fill((), n_site)
58-
y_global_true, y_true = f(θP, θMs_true, zip())
59-
σ_o = 0.01
66+
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
67+
f = get_hybridcase_PBmodel(case; scenario)
68+
xP = fill((;S1=xP_S1, S2=xP_S2), n_site)
69+
y_global_true, y_true = f(θP, θMs_true, xP)
70+
σ_o = FloatType(0.01)
6071
#σ_o = 0.002
61-
y_global_o = y_global_true .+ randn(rng, size(y_global_true)) .* σ_o
62-
y_o = y_true .+ randn(rng, size(y_true)) .* σ_o
72+
y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o
73+
y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o
6374
(;
6475
xM,
76+
n_site,
6577
θP_true = θP,
6678
θMs_true,
6779
xP,
@@ -72,3 +84,6 @@ function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase,
7284
σ_o = fill(σ_o, size(y_true,1)),
7385
)
7486
end
87+
88+
89+

src/HybridProblem.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
    HybridProblem(θP, θM, g, ϕg, f, transM, transP, n_covar, n_batch, train_loader)

Concrete `AbstractHybridCase` that stores all of its components as fields, so that
the `get_hybridcase_*` accessors defined for it can return them directly.

# Arguments
- `θP`, `θM`: `ComponentVector`s, returned as parameter templates by
  `get_hybridcase_par_templates`.
- `g`, `ϕg`: an `AbstractModelApplicator` and its parameters, returned together by
  `get_hybridcase_MLapplicator`.
- `f`: the `Function` returned by `get_hybridcase_PBmodel`.
- `transM`, `transP`: each a `Function` or a `Bijectors.Transform`.
  Note the positional order: `transM` comes before `transP`.
- `n_covar`, `n_batch`: integers reported by `get_hybridcase_sizes`.
- `train_loader`: the `DataLoader` returned by `get_hybridcase_train_dataloader`.
"""
struct HybridProblem <: AbstractHybridCase
    θP
    θM
    transP
    transM
    n_covar
    n_batch
    f
    g
    ϕg
    train_loader
    # inner constructor to constrain the types
    function HybridProblem(
            θP::CA.ComponentVector, θM::CA.ComponentVector,
            g::AbstractModelApplicator, ϕg,
            f::Function,
            transM::Union{Function, Bijectors.Transform},
            transP::Union{Function, Bijectors.Transform},
            n_covar::Integer, n_batch::Integer,
            train_loader::DataLoader)
        # `new` fills fields in declaration order (θP, θM, transP, transM, ...).
        # Pass the transforms in that order so each is stored under its own name;
        # the argument order of this constructor (transM, transP) is unchanged.
        return new(θP, θM, transP, transM, n_covar, n_batch, f, g, ϕg, train_loader)
    end
end
24+
25+
"""
    get_hybridcase_par_templates(prob::HybridProblem; scenario = ())

Return the `NamedTuple` `(; θP, θM)` holding the two parameter vectors stored in
`prob`. The `scenario` keyword is accepted but not used by this method.
"""
function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ())
    # property shorthand: field names of the result are taken from the accessed fields
    return (; prob.θP, prob.θM)
end
28+
29+
"""
    get_hybridcase_sizes(prob::HybridProblem; scenario = ())

Return the `NamedTuple` `(; n_covar, n_batch, n_θM, n_θP)`.

`n_covar` and `n_batch` are read from `prob`; `n_θM` and `n_θP` are the lengths of
`prob.θM` and `prob.θP`. The `scenario` keyword is accepted but not used.
"""
function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ())
    (; n_covar, n_batch, θM, θP) = prob
    return (; n_covar, n_batch, n_θM = length(θM), n_θP = length(θP))
end
34+
35+
"""
    get_hybridcase_PBmodel(prob::HybridProblem; scenario = ())

Return the function stored in `prob.f`.
The `scenario` keyword is accepted but not used by this method.
"""
get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ()) = prob.f
38+
39+
"""
    get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario = ())

Return the tuple `(g, ϕg)` stored in `prob`.

Both `ml_engine` and `scenario` are accepted but not used by this method: the
applicator and its parameters were fixed when `prob` was constructed.
"""
function get_hybridcase_MLapplicator(
        prob::HybridProblem, ml_engine; scenario::NTuple = ())
    return (prob.g, prob.ϕg)
end
42+
43+
"""
    get_hybridcase_train_dataloader(prob::HybridProblem, rng = Random.default_rng();
        scenario = ())

Return the data loader stored in `prob.train_loader`.
Neither `rng` nor `scenario` is used by this method.
"""
function get_hybridcase_train_dataloader(
        prob::HybridProblem, rng::AbstractRNG = Random.default_rng();
        scenario = ())
    loader = prob.train_loader
    return loader
end
48+
49+
50+
# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
51+
# eltype(prob.θM)
52+
# end
53+
54+
55+

0 commit comments

Comments
 (0)