11struct DoubleMMCase <: AbstractHybridCase end
22
3- const S1 = [1.0 , 1.0 , 1.0 , 1.0 , 0.4 , 0.3 , 0.1 ]
4- const S2 = [1.0 , 3.0 , 4.0 , 5.0 , 5.0 , 5.0 , 5.0 ]
53
6- θP = CA. ComponentVector (r0 = 0.3 , K2 = 2.0 )
7- θM = CA. ComponentVector (r1 = 0.5 , K1 = 0.2 )
4+ θP = CA. ComponentVector {Float32} (r0 = 0.3 , K2 = 2.0 )
5+ θM = CA. ComponentVector {Float32} (r1 = 0.5 , K1 = 0.2 )
6+
7+ transP = elementwise (exp)
8+ transM = Stacked (elementwise (identity), elementwise (exp))
9+
810
911const int_θdoubleMM = ComponentArrayInterpreter (flatten1 (CA. ComponentVector (; θP, θM)))
1012
11- function f_doubleMM (θ:: AbstractVector )
13+ function f_doubleMM (θ:: AbstractVector , x )
1214 # extract parameters not depending on order, i.e whether they are in θP or θM
1315 θc = int_θdoubleMM (θ)
1416 r0, r1, K1, K2 = θc[(:r0 , :r1 , :K1 , :K2 )]
15- y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
17+ y = r0 .+ r1 .* x . S1 ./ (K1 .+ x . S1) .* x . S2 ./ (K2 .+ x . S2)
1618 return (y)
1719end
1820
19- function HybridVariationalInference . get_hybridcase_par_templates (:: DoubleMMCase ; scenario:: NTuple = ())
21+ function HVI . get_hybridcase_par_templates (:: DoubleMMCase ; scenario:: NTuple = ())
2022 (; θP, θM)
2123end
2224
23- function HybridVariationalInference. get_hybridcase_sizes (:: DoubleMMCase ; scenario = ())
25+ function HVI. get_hybridcase_transforms (:: AbstractHybridCase ; scenario:: NTuple = ())
26+ (; transP, transM)
27+ end
28+
29+ function HVI. get_hybridcase_sizes (:: DoubleMMCase ; scenario = ())
2430 n_covar_pc = 2
2531 n_covar = n_covar_pc + 3 # linear dependent
26- n_site = 10 ^ n_covar_pc
32+ # n_site = 10^n_covar_pc
2733 n_batch = 10
2834 n_θM = length (θM)
2935 n_θP = length (θP)
30- (; n_covar, n_site, n_batch, n_θM, n_θP)
36+ # (; n_covar, n_site, n_batch, n_θM, n_θP)
37+ (; n_covar, n_batch, n_θM, n_θP)
3138end
3239
33- function HybridVariationalInference . gen_hybridcase_PBmodel (:: DoubleMMCase ; scenario:: NTuple = ())
34- fsite = (θ, x_site) -> f_doubleMM (θ) # omit x_site drivers
40+ function HVI . get_hybridcase_PBmodel (:: DoubleMMCase ; scenario:: NTuple = ())
41+ # fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
3542 function f_doubleMM_with_global (θP:: AbstractVector , θMs:: AbstractMatrix , x)
36- pred_sites = applyf (fsite , θMs, θP, x)
43+ pred_sites = applyf (f_doubleMM , θMs, θP, x)
3744 pred_global = eltype (pred_sites)[]
3845 return pred_global, pred_sites
3946 end
4047end
4148
42- function HybridVariationalInference . get_hybridcase_FloatType (:: DoubleMMCase ; scenario)
43- return Float32
44- end
49+ # function HVI .get_hybridcase_FloatType(::DoubleMMCase; scenario)
50+ # return Float32
51+ # end
4552
46- function HybridVariationalInference. gen_hybridcase_synthetic (case:: DoubleMMCase , rng:: AbstractRNG ;
53+ const xP_S1 = Float32[1.0 , 1.0 , 1.0 , 1.0 , 0.4 , 0.3 , 0.1 ]
54+ const xP_S2 = Float32[1.0 , 3.0 , 4.0 , 5.0 , 5.0 , 5.0 , 5.0 ]
55+
56+ function HVI. gen_hybridcase_synthetic (case:: DoubleMMCase , rng:: AbstractRNG ;
4757 scenario = ())
4858 n_covar_pc = 2
49- (; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes (case; scenario)
59+ n_site = 200
60+ (; n_covar, n_θM, n_θP) = get_hybridcase_sizes (case; scenario)
5061 FloatType = get_hybridcase_FloatType (case; scenario)
5162 xM, θMs_true0 = gen_cov_pred (rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
5263 rhodec = 8 , is_using_dropout = false )
5364 int_θMs_sites = ComponentArrayInterpreter (θM, (n_site,))
5465 # normalize to be distributed around the prescribed true values
55- θMs_true = int_θMs_sites (scale_centered_at (θMs_true0, θM, 0.1 ))
56- f = gen_hybridcase_PBmodel (case; scenario)
57- xP = fill ((), n_site)
58- y_global_true, y_true = f (θP, θMs_true, zip () )
59- σ_o = 0.01
66+ θMs_true = int_θMs_sites (scale_centered_at (θMs_true0, θM, FloatType ( 0.1 ) ))
67+ f = get_hybridcase_PBmodel (case; scenario)
68+ xP = fill ((;S1 = xP_S1, S2 = xP_S2 ), n_site)
69+ y_global_true, y_true = f (θP, θMs_true, xP )
70+ σ_o = FloatType ( 0.01 )
6071 # σ_o = 0.002
61- y_global_o = y_global_true .+ randn (rng, size (y_global_true)) .* σ_o
62- y_o = y_true .+ randn (rng, size (y_true)) .* σ_o
72+ y_global_o = y_global_true .+ randn (rng, FloatType, size (y_global_true)) .* σ_o
73+ y_o = y_true .+ randn (rng, FloatType, size (y_true)) .* σ_o
6374 (;
6475 xM,
76+ n_site,
6577 θP_true = θP,
6678 θMs_true,
6779 xP,
@@ -72,3 +84,6 @@ function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase,
7284 σ_o = fill (σ_o, size (y_true,1 )),
7385 )
7486end
87+
88+
89+
0 commit comments