Skip to content

Commit 9e18e72

Browse files
committed
allow specifying initial uncertainty by init_hybrid_ϕunc
1 parent f3fb039 commit 9e18e72

File tree

7 files changed

+71
-57
lines changed

7 files changed

+71
-57
lines changed

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ scatterplot(θMs_true[2,:], θMs[2,:])
5555
prob1o.θP
5656
scatterplot(vec(y_true), vec(y_pred))
5757

58-
# still overestimating θMs
58+
# still overestimating θMs and θP
5959

6060
() -> begin # with more iterations?
6161
prob2 = prob1o

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ include("cholesky.jl")
6060
export neg_elbo_transnorm_gf, predict_gf
6161
include("elbo.jl")
6262

63-
export init_hybrid_params
63+
export init_hybrid_params, init_hybrid_ϕunc
6464
include("init_hybrid_params.jl")
6565

6666
export AbstractHybridSolver, HybridPointSolver, HybridPosteriorSolver

src/cholesky.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -290,20 +290,19 @@ end
290290
Return number of correlation coefficients for a correlation matrix of size `(npar x npar)`
291291
With blocks ending at the positions given by the vector `cor_ends`.
292292
"""
293-
function get_cor_count(cor_ends)
293+
function get_cor_count(cor_ends::AbstractVector)
294294
sum(get_cor_counts(cor_ends))
295295
end
296296
function get_cor_counts(cor_ends::AbstractVector{T}) where {T}
297297
isempty(cor_ends) && return (zero(T))
298298
cnt_blocks = (
299299
begin
300-
cor_start = i == 1 ? one(T) : cor_ends[i-1] + one(T)
301-
cor_ends[i] - cor_start
300+
i == 1 ? cor_ends[i] : cor_ends[i] - cor_ends[i-1]
302301
end for i in 1:length(cor_ends)
303302
)
304-
sumn.(cnt_blocks)
303+
get_cor_count.(cnt_blocks)
305304
end
306-
function get_cor_count(n_par::T) where {T<:Integer}
305+
function get_cor_count(n_par::T) where T<:Number # <: Integer causes problems with AD
307306
sumn(n_par - one(T))
308307
end
309308

@@ -318,7 +317,7 @@ E.g. For a matrix with a 3x3, a 2x2, and another single-entry block,
318317
the blocks end at columns (3,5,6). It defaults to a single entire block.
319318
"""
320319
function transformU_block_cholesky1(
321-
v::AbstractVector{T}, cor_ends::AbstractVector{IT}=Int[]) where {T,IT<:Integer}
320+
v::AbstractVector{T}, cor_ends::AbstractVector{TI}=Int[]) where {T,TI<:Integer}
322321
#@show v, cor_ends
323322
if length(cor_ends) <= 1 # if there is only one block, return it
324323
return transformU_cholesky1(v)
@@ -327,7 +326,7 @@ function transformU_block_cholesky1(
327326
#@show cor_counts
328327
ranges = ChainRulesCore.@ignore_derivatives (
329328
begin
330-
cor_start = (i == 1 ? 1 : cor_counts[i-1] + one(IT))
329+
cor_start = (i == 1 ? one(TI) : cor_counts[i-1] + one(TI))
331330
cor_start:cor_counts[i]
332331
end for i in 1:length(cor_counts)
333332
)

src/init_hybrid_params.jl

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,41 +12,38 @@ Returns a NamedTuple of
1212
1313
# Arguments
1414
- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
15+
- `cor_ends`: NamedTuple with entries, `P`, and `M`, respectively with
16+
integer vectors of ending columns of parameter blocks
1517
- `ϕg`: vector of parameters to optimize, as returned by `get_hybridproblem_MLapplicator`
1618
- `n_batch`: the number of sites to predict in each mini-batch
1719
- `transP`, `transM`: the Bijector.Transformations for the global and site-dependent
1820
parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`.
1921
It is the transformation going from unconstrained to constrained space: θ = Tinv(ζ),
2022
because this direction is used much more often.
23+
- `ϕunc0`: initial uncertainty parameters, a ComponentVector with the format returned by `init_hybrid_ϕunc`.
2124
"""
22-
function init_hybrid_params(θP, θM, cor_ends::NamedTuple, ϕg, n_batch;
23-
transP=elementwise(identity), transM=elementwise(identity))
25+
function init_hybrid_params(θP::AbstractVector{FT}, θM::AbstractVector{FT},
26+
cor_ends::NamedTuple, ϕg::AbstractVector{FT}, n_batch;
27+
transP = elementwise(identity), transM = elementwise(identity),
28+
ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT))) where {FT}
2429
n_θP = length(θP)
2530
n_θM = length(θM)
31+
@assert cor_ends.P[end] == n_θP
32+
@assert cor_ends.M[end] == n_θM
2633
n_ϕg = length(ϕg)
2734
# check translating parameters - can match length?
2835
_ = Bijectors.inverse(transP)(θP)
2936
_ = Bijectors.inverse(transM)(θM)
30-
FT = eltype(θM)
31-
# zero correlation matrices
32-
# ρsP = zeros(FT, sum(1:(n_θP - 1)))
33-
# ρsM = zeros(FT, sum(1:(n_θM - 1)))
34-
ρsP = zeros(FT, get_cor_count(cor_ends.P))
35-
ρsM = zeros(FT, get_cor_count(cor_ends.M))
36-
ϕunc0 = CA.ComponentVector(;
37-
logσ2_logP = fill(FT(-10.0), n_θP),
38-
coef_logσ2_logMs = reduce(hcat, (FT[-10.0, 0.0] for _ in 1:n_θM)),
39-
ρsP,
40-
ρsM)
4137
ϕ = CA.ComponentVector(;
42-
μP = apply_preserve_axes(inverse(transP),θP),
38+
μP = apply_preserve_axes(inverse(transP), θP),
4339
ϕg = ϕg,
44-
unc = ϕunc0);
40+
unc = ϕunc0)
4541
#
46-
get_transPMs = let transP=transP, transM=transM, n_θP=n_θP, n_θM=n_θM
42+
get_transPMs = let transP = transP, transM = transM, n_θP = n_θP, n_θM = n_θM
4743
function get_transPMs_inner(n_site)
4844
transMs = ntuple(i -> transM, n_site)
49-
ranges = vcat([1:n_θP], [(n_θP + i0*n_θM) .+ (1:n_θM) for i0 in 0:(n_site-1)])
45+
ranges = vcat(
46+
[1:n_θP], [(n_θP + i0 * n_θM) .+ (1:n_θM) for i0 in 0:(n_site - 1)])
5047
transPMs = Stacked((transP, transMs...), ranges)
5148
transPMs
5249
end
@@ -56,37 +53,54 @@ function init_hybrid_params(θP, θM, cor_ends::NamedTuple, ϕg, n_batch;
5653
# inv_trans_gu = Stacked(
5754
# (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges))
5855
# ϕ = inv_trans_gu(CA.getdata(ϕt))
59-
get_ca_int_PMs = let
56+
get_ca_int_PMs = let
6057
function get_ca_int_PMs_inner(n_site)
61-
ComponentArrayInterpreter(CA.ComponentVector(; P=θP,
62-
Ms = CA.ComponentMatrix(
63-
zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site))))
58+
ComponentArrayInterpreter(CA.ComponentVector(; P = θP,
59+
Ms = CA.ComponentMatrix(
60+
zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site))))
6461
end
65-
6662
end
6763
interpreters = map(get_concrete,
68-
(;
69-
μP_ϕg_unc = ComponentArrayInterpreter(ϕ),
70-
PMs = get_ca_int_PMs(n_batch),
71-
unc = ComponentArrayInterpreter(ϕunc0)
72-
))
73-
(;ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs)
64+
(;
65+
μP_ϕg_unc = ComponentArrayInterpreter(ϕ),
66+
PMs = get_ca_int_PMs(n_batch),
67+
unc = ComponentArrayInterpreter(ϕunc0)
68+
))
69+
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs)
7470
end
7571

76-
function init_hybrid_ϕunc(logσ2_logP::AbstractVector{FT}, coef_logσ2_logMs, cor_ends;
77-
ρ0 = zeros(FT)) where FT
78-
79-
n_θP = length(θP)
80-
n_θM = length(θM)
81-
n_ϕg = length(ϕg)
82-
# TODO zero correlation matrices
83-
ρsP = zeros(FT, sum(1:(n_θP - 1)))
84-
ρsM = zeros(FT, sum(1:(n_θM - 1)))
85-
ϕunc0 = CA.ComponentVector(;
86-
logσ2_logP = fill(FT(-10.0), n_θP),
87-
coef_logσ2_logMs = reduce(hcat, (FT[-10.0, 0.0] for _ in 1:n_θM)),
88-
ρsP,
89-
ρsM)
90-
end
91-
92-
72+
"""
73+
init_hybrid_ϕunc(cor_ends, ρ0=0f0, coef_logσ2_logM=[-10.0, 0.0]; logσ2_logP, coef_logσ2_logMs, ρsP, ρsM)
74+
75+
Initialize the vector of additional parameters of the approximate posterior.
76+
77+
Arguments:
78+
- `cor_ends`: NamedTuple with entries, `P`, and `M`, respectively with
79+
integer vectors of ending columns of parameter blocks
80+
- `ρ0`: default entry for `ρsP` and `ρsM`, defaults to `0f0`.
81+
- `coef_logσ2_logM`: default column for `coef_logσ2_logMs`, defaults to `[-10.0, 0.0]`
82+
83+
Returns a `ComponentVector` of
84+
- `logσ2_logP`: vector of log-variances of ζP (on log scale).
85+
defaults to -10
86+
- `coef_logσ2_logMs`: offset and slope for the log-variances of ζM scaling with
87+
its value given by columns for each parameter in ζM, defaults to `[-10, 0]`
88+
- `ρsP` and `ρsM`: parameterization of the upper triangular cholesky factor
89+
of the correlation matrices of ζP and ζM, default to all entries `ρ0`, which defaults to zero.
90+
"""
91+
function init_hybrid_ϕunc(
92+
cor_ends::NamedTuple,
93+
ρ0::FT = 0.0f0,
94+
coef_logσ2_logM::AbstractVector{FT} = FT[-10.0, 0.0];
95+
logσ2_logP::AbstractVector{FT} = fill(FT(-10.0), cor_ends.P[end]),
96+
coef_logσ2_logMs::AbstractMatrix{FT} = reduce(
97+
hcat, (coef_logσ2_logM for _ in 1:cor_ends.M[end])),
98+
ρsP = fill(ρ0, get_cor_count(cor_ends.P)),
99+
ρsM = fill(ρ0, get_cor_count(cor_ends.M)),
100+
) where {FT}
101+
CA.ComponentVector(;
102+
logσ2_logP,
103+
coef_logσ2_logMs,
104+
ρsP,
105+
ρsM)
106+
end

src/util_ca.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function cpu_ca end
77
# define in FluxExt
88

99
function apply_preserve_axes(f, ca::CA.ComponentArray)
10-
CA.ComponentArray(f(ca), CA.getaxes(ca))
10+
CA.ComponentArray(f(CA.getdata(ca)), CA.getaxes(ca))
1111
end
1212

1313

test/test_cholesky_structure.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ end;
162162
U = CP.transformU_block_cholesky1(v, cor_ends)
163163
@test diag(U' * U) ones(4)
164164
@test U[1:3, 4:4] zeros(3, 1)
165-
gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(v, cor_ends)), v)[1] # works nice
165+
gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(v, cor_ends)), v)[1]; # works nice
166166
# degenerate case of no correlations
167167
vc0 = CA.ComponentVector{Float32}()
168168
cor_ends0 = get_ca_ends(vc0)
@@ -171,7 +171,7 @@ end;
171171
#collect(ns)
172172
U = CP.transformU_block_cholesky1(CA.getdata(ρ0), cor_ends0)
173173
@test diag(U) == [1f0]
174-
gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(ρ0, cor_ends0)), v)[1] # works nice
174+
gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(ρ0, cor_ends0)), v)[1]; # works nice
175175

176176
if CUDA.functional() # only run the test, if CUDA is working (not on Github ci)
177177
vc = v_orig = CA.ComponentVector(b1 = CuArray(1.0f0:3.0f0), b2 = CuArray([5.0f0]))

test/test_elbo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ cor_ends = get_hybridproblem_cor_ends(prob; scenario)
3939
# transP = elementwise(exp)
4040
# transM = Stacked(elementwise(identity), elementwise(exp))
4141
#transM = Stacked(elementwise(identity), elementwise(exp), elementwise(exp)) # test mismatch
42+
ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT))
4243
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
4344
θP_true, θMs_true[:, 1], cor_ends, ϕg0, n_batch; transP, transM);
4445
ϕ_ini = ϕ

0 commit comments

Comments
 (0)