@@ -12,41 +12,38 @@ Returns a NamedTuple of
1212
1313# Arguments
1414- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
15+ - `cor_ends`: NamedTuple with entries, `P`, and `M`, respectively with
16+ integer vectors of ending columns of parameters blocks
1517- `ϕg`: vector of parameters to optimize, as returned by `get_hybridproblem_MLapplicator`
1618- `n_batch`: the number of sites to predicted in each mini-batch
1719- `transP`, `transM`: the Bijector.Transformations for the global and site-dependent
1820 parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`.
1921 Its the transformation froing from unconstrained to constrained space: θ = Tinv(ζ),
2022 because this direction is used much more often.
23+ - `ϕunc0` initial uncertainty parameters, ComponentVector wiht format of `init_hybrid_ϕunc.`
2124"""
22- function init_hybrid_params (θP, θM, cor_ends:: NamedTuple , ϕg, n_batch;
23- transP= elementwise (identity), transM= elementwise (identity))
25+ function init_hybrid_params (θP:: AbstractVector{FT} , θM:: AbstractVector{FT} ,
26+ cor_ends:: NamedTuple , ϕg:: AbstractVector{FT} , n_batch;
27+ transP = elementwise (identity), transM = elementwise (identity),
28+ ϕunc0 = init_hybrid_ϕunc (cor_ends, zero (FT))) where {FT}
2429 n_θP = length (θP)
2530 n_θM = length (θM)
31+ @assert cor_ends. P[end ] == n_θP
32+ @assert cor_ends. M[end ] == n_θM
2633 n_ϕg = length (ϕg)
2734 # check translating parameters - can match length?
2835 _ = Bijectors. inverse (transP)(θP)
2936 _ = Bijectors. inverse (transM)(θM)
30- FT = eltype (θM)
31- # zero correlation matrices
32- # ρsP = zeros(FT, sum(1:(n_θP - 1)))
33- # ρsM = zeros(FT, sum(1:(n_θM - 1)))
34- ρsP = zeros (FT, get_cor_count (cor_ends. P))
35- ρsM = zeros (FT, get_cor_count (cor_ends. M))
36- ϕunc0 = CA. ComponentVector (;
37- logσ2_logP = fill (FT (- 10.0 ), n_θP),
38- coef_logσ2_logMs = reduce (hcat, (FT[- 10.0 , 0.0 ] for _ in 1 : n_θM)),
39- ρsP,
40- ρsM)
4137 ϕ = CA. ComponentVector (;
42- μP = apply_preserve_axes (inverse (transP),θP),
38+ μP = apply_preserve_axes (inverse (transP), θP),
4339 ϕg = ϕg,
44- unc = ϕunc0);
40+ unc = ϕunc0)
4541 #
46- get_transPMs = let transP= transP, transM= transM, n_θP= n_θP, n_θM= n_θM
42+ get_transPMs = let transP = transP, transM = transM, n_θP = n_θP, n_θM = n_θM
4743 function get_transPMs_inner (n_site)
4844 transMs = ntuple (i -> transM, n_site)
49- ranges = vcat ([1 : n_θP], [(n_θP + i0* n_θM) .+ (1 : n_θM) for i0 in 0 : (n_site- 1 )])
45+ ranges = vcat (
46+ [1 : n_θP], [(n_θP + i0 * n_θM) .+ (1 : n_θM) for i0 in 0 : (n_site - 1 )])
5047 transPMs = Stacked ((transP, transMs... ), ranges)
5148 transPMs
5249 end
@@ -56,37 +53,54 @@ function init_hybrid_params(θP, θM, cor_ends::NamedTuple, ϕg, n_batch;
5653 # inv_trans_gu = Stacked(
5754 # (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges))
5855 # ϕ = inv_trans_gu(CA.getdata(ϕt))
59- get_ca_int_PMs = let
56+ get_ca_int_PMs = let
6057 function get_ca_int_PMs_inner (n_site)
61- ComponentArrayInterpreter (CA. ComponentVector (; P= θP,
62- Ms = CA. ComponentMatrix (
63- zeros (n_θM, n_site), first (CA. getaxes (θM)), CA. Axis (i = 1 : n_site))))
58+ ComponentArrayInterpreter (CA. ComponentVector (; P = θP,
59+ Ms = CA. ComponentMatrix (
60+ zeros (n_θM, n_site), first (CA. getaxes (θM)), CA. Axis (i = 1 : n_site))))
6461 end
65-
6662 end
6763 interpreters = map (get_concrete,
68- (;
69- μP_ϕg_unc = ComponentArrayInterpreter (ϕ),
70- PMs = get_ca_int_PMs (n_batch),
71- unc = ComponentArrayInterpreter (ϕunc0)
72- ))
73- (;ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs)
64+ (;
65+ μP_ϕg_unc = ComponentArrayInterpreter (ϕ),
66+ PMs = get_ca_int_PMs (n_batch),
67+ unc = ComponentArrayInterpreter (ϕunc0)
68+ ))
69+ (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs)
7470end
7571
76- function init_hybrid_ϕunc (logσ2_logP:: AbstractVector{FT} , coef_logσ2_logMs, cor_ends;
77- ρ0 = zeros (FT)) where FT
78-
79- n_θP = length (θP)
80- n_θM = length (θM)
81- n_ϕg = length (ϕg)
82- # TODO zero correlation matrices
83- ρsP = zeros (FT, sum (1 : (n_θP - 1 )))
84- ρsM = zeros (FT, sum (1 : (n_θM - 1 )))
85- ϕunc0 = CA. ComponentVector (;
86- logσ2_logP = fill (FT (- 10.0 ), n_θP),
87- coef_logσ2_logMs = reduce (hcat, (FT[- 10.0 , 0.0 ] for _ in 1 : n_θM)),
88- ρsP,
89- ρsM)
90- end
91-
92-
72+ """
73+ init_hybrid_ϕunc(cor_ends, ρ0=0f0; logσ2_logP, coef_logσ2_logMs, ρsP, ρsM)
74+
75+ Initialize vector of additional parameter of the approximate posterior.
76+
77+ Arguments:
78+ - `cor_ends`: NamedTuple with entries, `P`, and `M`, respectively with
79+ integer vectors of ending columns of parameters blocks
80+ - `ρ0`: default entry for ρsP and ρsM, defaults = 0f0.
81+ - `coef_logσ2_logM`: default column for `coef_logσ2_logMs`, defaults to `[-10.0, 0.0]`
82+
83+ Returns a `ComponentVector` of
84+ - `logσ2_logP`: vector of log-variances of ζP (on log scale).
85+ defaults to -10
86+ - `coef_logσ2_logMs`: offset and slope for the log-variances of ζM scaling with
87+ its value given by columns for each parameter in ζM, defaults to `[-10, 0]`
88+ - `ρsP` and `ρsM`: parameterization of the upper triangular cholesky factor
89+ of the correlation matrices of ζP and ζM, default to all entries `ρ0`, which defaults to zero.
90+ """
91+ function init_hybrid_ϕunc (
92+ cor_ends:: NamedTuple ,
93+ ρ0:: FT = 0.0f0 ,
94+ coef_logσ2_logM:: AbstractVector{FT} = FT[- 10.0 , 0.0 ];
95+ logσ2_logP:: AbstractVector{FT} = fill (FT (- 10.0 ), cor_ends. P[end ]),
96+ coef_logσ2_logMs:: AbstractMatrix{FT} = reduce (
97+ hcat, (coef_logσ2_logM for _ in 1 : cor_ends. M[end ])),
98+ ρsP = fill (ρ0, get_cor_count (cor_ends. P)),
99+ ρsM = fill (ρ0, get_cor_count (cor_ends. M)),
100+ ) where {FT}
101+ CA. ComponentVector (;
102+ logσ2_logP,
103+ coef_logσ2_logMs,
104+ ρsP,
105+ ρsM)
106+ end
0 commit comments