@@ -14,12 +14,19 @@ Returns a NamedTuple of
1414- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
1515- `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator`
1616- `n_batch`: the number of sites to predicted in each mini-batch
17- - `transP`, `transM`: the Transformations for the global and site-dependent parameters
17+ - `transP`, `transM`: the Bijector.Transformations for the global and site-dependent
18+ parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`.
19+ Its the transformation froing from unconstrained to constrained space: θ = Tinv(ζ),
20+ because this direction is used much more often.
1821"""
19- function init_hybrid_params (θP, θM, ϕg, n_batch; transP= asℝ, transM= asℝ)
22+ function init_hybrid_params (θP, θM, ϕg, n_batch;
23+ transP= elementwise (identity), transM= elementwise (identity))
2024 n_θP = length (θP)
2125 n_θM = length (θM)
2226 n_ϕg = length (ϕg)
27+ # check translating parameters - can match length?
28+ _ = Bijectors. inverse (transP)(θP)
29+ _ = Bijectors. inverse (transM)(θM)
2330 # zero correlation matrices
2431 ρsP = zeros (sum (1 : (n_θP - 1 )))
2532 ρsM = zeros (sum (1 : (n_θM - 1 )))
@@ -28,39 +35,35 @@ function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ)
2835 coef_logσ2_logMs = reduce (hcat, ([- 10.0 , 0.0 ] for _ in 1 : n_θM)),
2936 ρsP,
3037 ρsM)
31- ϕt = CA. ComponentVector (;
32- μP = θP ,
38+ ϕ = CA. ComponentVector (;
39+ μP = inverse (transP)(θP) ,
3340 ϕg = ϕg,
3441 unc = ϕunc0);
3542 #
3643 get_transPMs = let transP= transP, transM= transM, n_θP= n_θP, n_θM= n_θM
3744 function get_transPMs_inner (n_site)
38- transPMs = as (
39- (P = as (Array, transP, n_θP),
40- Ms = as (Array, transM, n_θM, n_site)))
45+ transMs = ntuple (i -> transM, n_site)
46+ ranges = vcat ([1 : n_θP], [(n_θP + i0* n_θM) .+ (1 : n_θM) for i0 in 0 : (n_site- 1 )])
47+ transPMs = Stacked ((transP, transMs... ), ranges)
48+ transPMs
4149 end
4250 end
4351 transPMs_batch = get_transPMs (n_batch)
44- trans_gu = as (
45- (μP = as (Array, asℝ₊, n_θP),
46- ϕg = as (Array, n_ϕg),
47- unc = as (Array, length (ϕunc0))))
48- ϕ = inverse_ca (trans_gu, ϕt)
49- # trans_g = as(
50- # (μP = as(Array, asℝ₊, n_θP),
51- # ϕg = as(Array, n_ϕg)))
52- #
52+ # ranges = (P = 1:n_θP, ϕg = n_θP .+ (1:n_ϕg), unc = (n_θP + n_ϕg) .+ (1:length(ϕunc0)))
53+ # inv_trans_gu = Stacked(
54+ # (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges))
55+ # ϕ = inv_trans_gu(CA.getdata(ϕt))
5356 get_ca_int_PMs = let
5457 function get_ca_int_PMs_inner (n_site)
55- ComponentArrayInterpreter (CA. ComponentVector (; θP,
56- θMs = CA. ComponentMatrix (
58+ ComponentArrayInterpreter (CA. ComponentVector (; P = θP,
59+ Ms = CA. ComponentMatrix (
5760 zeros (n_θM, n_site), first (CA. getaxes (θM)), CA. Axis (i = 1 : n_site))))
5861 end
5962
6063 end
6164 interpreters = map (get_concrete,
6265 (;
63- μP_ϕg_unc = ComponentArrayInterpreter (ϕt ),
66+ μP_ϕg_unc = ComponentArrayInterpreter (ϕ ),
6467 PMs = get_ca_int_PMs (n_batch),
6568 unc = ComponentArrayInterpreter (ϕunc0)
6669 ))
0 commit comments