Skip to content

Commit cdbfa10

Browse files
committed
implement HybridProblem
1 parent 2b04312 commit cdbfa10

14 files changed

+244
-39
lines changed

dev/doubleMM.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ using MLUtils
1212
import Zygote
1313

1414
using CUDA
15-
using TransformVariables
1615
using OptimizationOptimisers
16+
using Bijectors
1717
using UnicodePlots
1818

1919
const case = DoubleMM.DoubleMMCase()
@@ -24,13 +24,13 @@ rng = StableRNG(111)
2424

2525
par_templates = get_hybridcase_par_templates(case; scenario)
2626

27-
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
27+
(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
2828

29-
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
29+
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
3030
) = gen_hybridcase_synthetic(case, rng; scenario);
3131

3232
#----- fit g to θMs_true
33-
g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
33+
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
3434

3535
function loss_g(ϕg, x, g)
3636
ζMs = g(x, ϕg) # predict the log of the parameters
@@ -51,7 +51,7 @@ loss_g(ϕg_opt1, xM, g)
5151
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
5252
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
5353

54-
f = gen_hybridcase_PBmodel(case; scenario)
54+
f = get_hybridcase_PBmodel(case; scenario)
5555

5656
#----------- fit g and θP to y_o
5757
() -> begin
@@ -84,6 +84,9 @@ end
8484
#---------- HVI
8585
logσ2y = 2 .* log.(σ_o)
8686
n_MC = 3
87+
transP = elementwise(exp)
88+
transM = Stacked(elementwise(identity), elementwise(exp))
89+
8790
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
8891
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
8992
ϕ_true = ϕ
@@ -188,7 +191,7 @@ end
188191

189192
ϕ = ϕ_ini |> Flux.gpu;
190193
xM_gpu = xM |> Flux.gpu;
191-
g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario);
194+
g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
192195

193196
# otpimize using LUX
194197
() -> begin

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function __init__()
2525
HVI.set_default_GPUHandler(FluxGPUDataHandler())
2626
end
2727

28-
function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
28+
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
2929
scenario::NTuple = ())
3030
(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
3131
FloatType = get_hybridcase_FloatType(case; scenario)

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
1212

1313
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
1414

15-
function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
15+
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
1616
scenario::NTuple=())
1717
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
1818
FloatType = get_hybridcase_FloatType(case; scenario)

src/DoubleMM/DoubleMM.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module DoubleMM
22

33
using HybridVariationalInference
4+
using HybridVariationalInference: HybridVariationalInference as HVI
45
using ComponentArrays: ComponentArrays as CA
56
using Random
67
using Combinatorics
78
using StatsFuns: logistic
9+
using Bijectors
810

911

1012
include("f_doubleMM.jl")

src/DoubleMM/f_doubleMM.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ const S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
66
θP = CA.ComponentVector(r0 = 0.3, K2 = 2.0)
77
θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
88

9+
transP = elementwise(exp)
10+
transM = Stacked(elementwise(identity), elementwise(exp))
11+
12+
913
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
1014

1115
function f_doubleMM::AbstractVector)
@@ -16,21 +20,26 @@ function f_doubleMM(θ::AbstractVector)
1620
return (y)
1721
end
1822

19-
function HybridVariationalInference.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
23+
function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
2024
(; θP, θM)
2125
end
2226

23-
function HybridVariationalInference.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
27+
function HVI.get_hybridcase_transforms(::AbstractHybridCase; scenario::NTuple = ())
28+
(; transP, transM)
29+
end
30+
31+
function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
2432
n_covar_pc = 2
2533
n_covar = n_covar_pc + 3 # linear dependent
26-
n_site = 10^n_covar_pc
34+
#n_site = 10^n_covar_pc
2735
n_batch = 10
2836
n_θM = length(θM)
2937
n_θP = length(θP)
30-
(; n_covar, n_site, n_batch, n_θM, n_θP)
38+
#(; n_covar, n_site, n_batch, n_θM, n_θP)
39+
(; n_covar, n_batch, n_θM, n_θP)
3140
end
3241

33-
function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
42+
function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
3443
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
3544
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
3645
pred_sites = applyf(fsite, θMs, θP, x)
@@ -39,21 +48,22 @@ function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scena
3948
end
4049
end
4150

42-
function HybridVariationalInference.get_hybridcase_FloatType(::DoubleMMCase; scenario)
51+
function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario)
4352
return Float32
4453
end
4554

46-
function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
55+
function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
4756
scenario = ())
4857
n_covar_pc = 2
49-
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
58+
n_site = 200
59+
(; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
5060
FloatType = get_hybridcase_FloatType(case; scenario)
5161
xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
5262
rhodec = 8, is_using_dropout = false)
5363
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
5464
# normalize to be distributed around the prescribed true values
5565
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, 0.1))
56-
f = gen_hybridcase_PBmodel(case; scenario)
66+
f = get_hybridcase_PBmodel(case; scenario)
5767
xP = fill((), n_site)
5868
y_global_true, y_true = f(θP, θMs_true, zip())
5969
σ_o = 0.01
@@ -62,6 +72,7 @@ function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase,
6272
y_o = y_true .+ randn(rng, size(y_true)) .* σ_o
6373
(;
6474
xM,
75+
n_site,
6576
θP_true = θP,
6677
θMs_true,
6778
xP,

src/HybridProblem.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
struct HybridProblem <: AbstractHybridCase
2+
θP
3+
θM
4+
transP
5+
transM
6+
n_covar
7+
n_batch
8+
f
9+
g
10+
ϕg
11+
# inner constructor to constrain the types
12+
function HybridProblem(
13+
θP::CA.ComponentVector, θM::CA.ComponentVector,
14+
transM::Union{Function, Bijectors.Transform},
15+
transP::Union{Function, Bijectors.Transform},
16+
n_covar::Integer, n_batch::Integer,
17+
f::Function, g::AbstractModelApplicator, ϕg)
18+
new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg)
19+
end
20+
end
21+
22+
function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ())
23+
(; θP = prob.θP, θM = prob.θM)
24+
end
25+
26+
function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ())
27+
n_θM = length(prob.θM)
28+
n_θP = length(prob.θP)
29+
(; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP)
30+
end
31+
32+
function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ())
33+
prob.f
34+
end
35+
36+
function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::NTuple = ());
37+
prob.g, prob.ϕg
38+
end
39+
40+
function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
41+
eltype(prob.θM)
42+
end
43+
44+
45+

src/HybridVariationalInference.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ include("ModelApplicator.jl")
2222
export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler
2323
include("GPUDataHandler.jl")
2424

25-
export AbstractHybridCase, gen_hybridcase_MLapplicator, gen_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic,
26-
get_hybridcase_par_templates, gen_cov_pred
25+
export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic,
26+
get_hybridcase_par_templates, get_hybridcase_transforms, gen_cov_pred
2727
include("hybrid_case.jl")
2828

29+
export HybridProblem
30+
include("HybridProblem.jl")
31+
2932
export applyf, gf, get_loss_gf
3033
include("gf.jl")
3134

src/hybrid_case.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ for different cases of hybrid problem setups
44
55
For a specific case, provide functions that specify details
66
- get_hybridcase_par_templates
7+
- get_hybridcase_transforms
78
- get_hybridcase_sizes
8-
- gen_hybridcase_MLapplicator
9-
- gen_hybridcase_PBmodel
9+
- get_hybridcase_MLapplicator
10+
- get_hybridcase_PBmodel
1011
optionally
1112
- gen_hybridcase_synthetic
1213
- get_hybridcase_FloatType (if it should differ from Float32)
@@ -20,6 +21,16 @@ Provide tuple of templates of ComponentVectors `θP` and `θM`.
2021
"""
2122
function get_hybridcase_par_templates end
2223

24+
25+
"""
26+
get_hybridcase_transforms(::AbstractHybridCase; scenario)
27+
28+
Return a NamedTupe of
29+
- `transP`: Bijectors.Transform for the global PBM parameters, θP
30+
- `transM`: Bijectors.Transform for the single-site PBM parameters, θM
31+
"""
32+
function get_hybridcase_transforms end
33+
2334
"""
2435
get_hybridcase_par_templates(::AbstractHybridCase; scenario)
2536
@@ -32,7 +43,7 @@ Provide a NamedTuple of number of
3243
function get_hybridcase_sizes end
3344

3445
"""
35-
gen_hybridcase_MLapplicator(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario=())
46+
get_hybridcase_MLapplicator(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario=())
3647
3748
Construct the machine learning model fro given problem case and ML-Framework and
3849
scenario.
@@ -44,10 +55,10 @@ returns a Tuple of
4455
- AbstractModelApplicator
4556
- initial parameter vector
4657
"""
47-
function gen_hybridcase_MLapplicator end
58+
function get_hybridcase_MLapplicator end
4859

4960
"""
50-
gen_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=())
61+
get_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=())
5162
5263
Construct the process-based model function
5364
`f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)`
@@ -60,7 +71,7 @@ returns a tuple of predictions with components
6071
- first, those that are constant across sites
6172
- second, those that vary across sites, with a column for each site
6273
"""
63-
function gen_hybridcase_PBmodel end
74+
function get_hybridcase_PBmodel end
6475

6576
"""
6677
gen_hybridcase_synthetic(::AbstractHybridCase, rng; scenario)
@@ -84,3 +95,5 @@ Determine the FloatType for given Case and scenario, defaults to Float32
8495
function get_hybridcase_FloatType(::AbstractHybridCase; scenario)
8596
return Float32
8697
end
98+
99+

src/init_hybrid_params.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Returns a NamedTuple of
1212
1313
# Arguments
1414
- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
15-
- `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator`
15+
- `ϕg`: vector of parameters to optimize, as returned by `get_hybridcase_MLapplicator`
1616
- `n_batch`: the number of sites to predicted in each mini-batch
1717
- `transP`, `transM`: the Bijector.Transformations for the global and site-dependent
1818
parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`.

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml
1313
@time @safetestset "test_logden_normal" include("test_logden_normal.jl")
1414
#@safetestset "test" include("test/test_doubleMM.jl")
1515
@time @safetestset "test_doubleMM" include("test_doubleMM.jl")
16+
#@safetestset "test" include("test/test_HybridProblem.jl")
17+
@time @safetestset "test_HybridProblem" include("test_HybridProblem.jl")
1618
#@safetestset "test" include("test/test_cholesky_structure.jl")
1719
@time @safetestset "test_cholesky_structure" include("test_cholesky_structure.jl")
1820
#@safetestset "test" include("test/test_sample_zeta.jl")

0 commit comments

Comments
 (0)