@@ -6,6 +6,10 @@ const S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
66θP = CA. ComponentVector (r0 = 0.3 , K2 = 2.0 )
77θM = CA. ComponentVector (r1 = 0.5 , K1 = 0.2 )
88
9+ transP = elementwise (exp)
10+ transM = Stacked (elementwise (identity), elementwise (exp))
11+
12+
913const int_θdoubleMM = ComponentArrayInterpreter (flatten1 (CA. ComponentVector (; θP, θM)))
1014
1115function f_doubleMM (θ:: AbstractVector )
@@ -16,21 +20,26 @@ function f_doubleMM(θ::AbstractVector)
1620 return (y)
1721end
1822
19- function HybridVariationalInference . get_hybridcase_par_templates (:: DoubleMMCase ; scenario:: NTuple = ())
23+ function HVI . get_hybridcase_par_templates (:: DoubleMMCase ; scenario:: NTuple = ())
2024 (; θP, θM)
2125end
2226
23- function HybridVariationalInference. get_hybridcase_sizes (:: DoubleMMCase ; scenario = ())
27+ function HVI. get_hybridcase_transforms (:: AbstractHybridCase ; scenario:: NTuple = ())
28+ (; transP, transM)
29+ end
30+
31+ function HVI. get_hybridcase_sizes (:: DoubleMMCase ; scenario = ())
2432 n_covar_pc = 2
2533 n_covar = n_covar_pc + 3 # linear dependent
26- n_site = 10 ^ n_covar_pc
34+ # n_site = 10^n_covar_pc
2735 n_batch = 10
2836 n_θM = length (θM)
2937 n_θP = length (θP)
30- (; n_covar, n_site, n_batch, n_θM, n_θP)
38+ # (; n_covar, n_site, n_batch, n_θM, n_θP)
39+ (; n_covar, n_batch, n_θM, n_θP)
3140end
3241
33- function HybridVariationalInference . gen_hybridcase_PBmodel (:: DoubleMMCase ; scenario:: NTuple = ())
42+ function HVI . get_hybridcase_PBmodel (:: DoubleMMCase ; scenario:: NTuple = ())
3443 fsite = (θ, x_site) -> f_doubleMM (θ) # omit x_site drivers
3544 function f_doubleMM_with_global (θP:: AbstractVector , θMs:: AbstractMatrix , x)
3645 pred_sites = applyf (fsite, θMs, θP, x)
@@ -39,21 +48,22 @@ function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scena
3948 end
4049end
4150
42- function HybridVariationalInference . get_hybridcase_FloatType (:: DoubleMMCase ; scenario)
51+ function HVI . get_hybridcase_FloatType (:: DoubleMMCase ; scenario)
4352 return Float32
4453end
4554
46- function HybridVariationalInference . gen_hybridcase_synthetic (case:: DoubleMMCase , rng:: AbstractRNG ;
55+ function HVI . gen_hybridcase_synthetic (case:: DoubleMMCase , rng:: AbstractRNG ;
4756 scenario = ())
4857 n_covar_pc = 2
49- (; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes (case; scenario)
58+ n_site = 200
59+ (; n_covar, n_θM, n_θP) = get_hybridcase_sizes (case; scenario)
5060 FloatType = get_hybridcase_FloatType (case; scenario)
5161 xM, θMs_true0 = gen_cov_pred (rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
5262 rhodec = 8 , is_using_dropout = false )
5363 int_θMs_sites = ComponentArrayInterpreter (θM, (n_site,))
5464 # normalize to be distributed around the prescribed true values
5565 θMs_true = int_θMs_sites (scale_centered_at (θMs_true0, θM, 0.1 ))
56- f = gen_hybridcase_PBmodel (case; scenario)
66+ f = get_hybridcase_PBmodel (case; scenario)
5767 xP = fill ((), n_site)
5868 y_global_true, y_true = f (θP, θMs_true, zip ())
5969 σ_o = 0.01
@@ -62,6 +72,7 @@ function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase,
6272 y_o = y_true .+ randn (rng, size (y_true)) .* σ_o
6373 (;
6474 xM,
75+ n_site,
6576 θP_true = θP,
6677 θMs_true,
6778 xP,
0 commit comments