Skip to content

Commit 151551d

Browse files
committed
enable independent subranges in parameter s
1 parent 241e54d commit 151551d

File tree

10 files changed

+272
-106
lines changed

10 files changed

+272
-106
lines changed

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_tru
282282

283283
# look at θP, θM1 of first site
284284
intm_PMs_gen = get_ca_int_PMs(n_site)
285-
ζs, _σ = HVI.generate_ζ(rng, g_flux, f, res.u, xM_gpu,
285+
ζs, _σ = HVI.generate_ζ(rng, g_flux, res.u, xM_gpu,
286286
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred);
287287
ζs = ζs |> Flux.cpu;
288288
θPM = vcat(θP_true, θMs_true[:, 1])

src/HybridProblem.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
struct HybridProblem <: AbstractHybridCase
22
θP
33
θM
4+
f
5+
g
6+
ϕg
47
transP
58
transM
9+
cor_starts # = (P=(1,),M=(1,))
610
n_covar
711
n_batch
8-
f
9-
g
10-
ϕg
1112
train_loader
1213
# inner constructor to constrain the types
1314
function HybridProblem(
@@ -17,15 +18,20 @@ struct HybridProblem <: AbstractHybridCase
1718
transM::Union{Function, Bijectors.Transform},
1819
transP::Union{Function, Bijectors.Transform},
1920
n_covar::Integer, n_batch::Integer,
20-
train_loader::DataLoader)
21-
new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg, train_loader)
21+
train_loader::DataLoader,
22+
cor_starts = (P=(1,), M=(1,)))
23+
new(θP, θM, f, g, ϕg, transM, transP, cor_starts, n_covar, n_batch, train_loader)
2224
end
2325
end
2426

2527
function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ())
2628
(; θP = prob.θP, θM = prob.θM)
2729
end
2830

31+
function get_hybridcase_transforms(prob::HybridProblem; scenario::NTuple = ())
32+
(; transP = prob.transP, transM = prob.transM)
33+
end
34+
2935
function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ())
3036
n_θM = length(prob.θM)
3137
n_θP = length(prob.θP)
@@ -46,6 +52,9 @@ function get_hybridcase_train_dataloader(
4652
return(prob.train_loader)
4753
end
4854

55+
function get_hybridcase_cor_starts(prob::HybridProblem; scenario = ())
56+
prob.cor_starts
57+
end
4958

5059
# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
5160
# eltype(prob.θM)

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ include("util_opt.jl")
4343
export neg_logden_indep_normal, entropy_MvNormal
4444
include("logden_normal.jl")
4545

46-
#export - all internal
46+
export get_ca_starts
4747
include("cholesky.jl")
4848

4949
export neg_elbo_transnorm_gf, predict_gf

src/cholesky.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,55 @@ function transformU_cholesky1(v::GPUArraysCore.AbstractGPUVector; n=invsumn(leng
252252
return U
253253
end
254254

255+
# function transformU_block_cholesky1(v::CA.ComponentVector;
256+
# ns=(invsumn(length(v[k])) + 1 for k in keys(v)) # may pass for efficiency
257+
# )
258+
# blocks = [transformU_cholesky1(v[k]; n) for (k, n) in zip(keys(v), ns)]
259+
# U = _create_blockdiag(v[first(keys(v))], blocks) # v only for dispatch: plain matrix for gpu
260+
# end
261+
262+
263+
"""
264+
get_ca_starts(vc::ComponentVector)
265+
266+
Return a tuple with starting positions of components in vc.
267+
Useful for providing information on correlations among subranges in a vector.
268+
"""
269+
function get_ca_starts(vc::CA.ComponentVector)
270+
(1, (1 .+ cumsum((length(vc[k]) for k in front(keys(vc)))))...)
271+
end
272+
"omit the last n elements of an iterator"
273+
front(itr, n=1) = Iterators.take(itr, length(itr)-n)
274+
275+
"""
276+
transformU_block_cholesky1(v::AbstractVector, cor_starts = (1,))
277+
278+
Transform a parameterization v of a blockdiagonal of upper triangular matrices
279+
into this matrix.
280+
`cor_starts` is an NTuple of Integers specifying the first column of each block.
281+
E.g. for a matrix with a 3x3, a 2x2, and another block,
282+
the blocks start at columns (1,4,6). It defaults to a single entire block.
283+
"""
284+
function transformU_block_cholesky1(v::AbstractVector, cor_starts = (1,))
285+
cor_starts_end = (cor_starts..., length(v)+1)
286+
ranges = ChainRulesCore.@ignore_derivatives (
287+
cor_starts_end[i]:(cor_starts_end[i+1]-1) for i in 1:length(cor_starts))
288+
blocks = [transformU_cholesky1(v[r]) for r in ranges]
289+
U = _create_blockdiag(v, blocks) # v only for dispatch: plain matrix for gpu
290+
return(U)
291+
end
292+
293+
function _create_blockdiag(::AbstractArray, blocks)
294+
BlockDiagonal(blocks)
295+
end
296+
297+
function _create_blockdiag(::GPUArraysCore.AbstractGPUArray, blocks)
298+
# impose no special structure
299+
cat(blocks...; dims=(1, 2))
300+
end
301+
302+
303+
255304
() -> begin
256305
tmp = sqrt.(sum(abs2, U_scaled, dims=1))
257306
tmp2 = sum(abs2, U_scaled, dims=1) .^ (-1 / 2)

src/elbo.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ expected value of the likelihood of observations.
2222
function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, xM::AbstractMatrix,
2323
xP, transPMs, interpreters::NamedTuple;
2424
n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler(),
25-
entropyN = 0.0,
25+
cor_starts=(P=(1,),M=(1,))
2626
)
27-
ζs, σ = generate_ζ(rng, g, f, ϕ, xM, interpreters; n_MC)
27+
ζs, σ = generate_ζ(rng, g, ϕ, xM, interpreters; n_MC, cor_starts)
2828
ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
2929
#ζi = first(eachcol(ζs_cpu))
3030
nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
@@ -48,13 +48,14 @@ Prediction function for hybrid model. Returns an Array `(n_obs, n_site, n_sample
4848
"""
4949
function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP, interpreters;
5050
get_transPMs, get_ca_int_PMs, n_sample_pred=200,
51-
gpu_data_handler=get_default_GPUHandler())
51+
gpu_data_handler=get_default_GPUHandler(),
52+
cor_starts=(P=(1,),M=(1,)))
5253
n_site = size(xM, 2)
5354
intm_PMs_gen = get_ca_int_PMs(n_site)
5455
trans_PMs_gen = get_transPMs(n_site)
5556
interpreters_gen = (; interpreters..., PMs = intm_PMs_gen)
56-
ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM),
57-
interpreters_gen; n_MC = n_sample_pred)
57+
ζs, _ = generate_ζ(rng, g, CA.getdata(ϕ), CA.getdata(xM),
58+
interpreters_gen; n_MC = n_sample_pred, cor_starts)
5859
ζs_cpu = gpu_data_handler(ζs) #
5960
y_pred = stack(map-> first(predict_y(
6061
ζ, xP, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
@@ -69,14 +70,14 @@ Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0`
6970
to the means extracted from parameters and predicted by the machine learning
7071
model.
7172
"""
72-
function generate_ζ(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix,
73-
interpreters::NamedTuple; n_MC=3)
73+
function generate_ζ(rng, g, ϕ::AbstractVector, xM::AbstractMatrix,
74+
interpreters::NamedTuple; n_MC=3, cor_starts=(P=(1,),M=(1,)))
7475
# see documentation of neg_elbo_transnorm_gf
7576
ϕc = interpreters.μP_ϕg_unc(CA.getdata(ϕ))
7677
μ_ζP = ϕc.μP
7778
ϕg = ϕc.ϕg
7879
μ_ζMs0 = g(xM, ϕg) # TODO provide μ_ζP to g
79-
ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC)
80+
ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_starts)
8081
#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
8182
ζ = stack(map(eachcol(ζ_resid)) do r
8283
rc = interpreters.PMs(r)
@@ -98,21 +99,21 @@ ComponentMarshellers
9899
- marsh_batch(n_batch)
99100
- marsh_unc(n_UncP, n_UncM, n_UncCorr)
100101
"""
101-
function sample_ζ_norm0(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix, ϕunc::AbstractVector, args...;
102-
n_MC=3)
102+
function sample_ζ_norm0(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix,
103+
args...; n_MC, cor_starts)
103104
n_θP, n_θMs = length(ζP), length(ζMs)
104105
urand = _create_random(rng, CA.getdata(ζP), n_θP + n_θMs, n_MC)
105-
sample_ζ_norm0(urand, ζP, ζMs, ϕunc, args...)
106+
sample_ζ_norm0(urand, ζP, ζMs, args...; cor_starts)
106107
end
107108

108109
function sample_ζ_norm0(urand::AbstractMatrix, ζP::AbstractVector{T}, ζMs::AbstractMatrix,
109-
ϕunc::AbstractVector, int_unc = ComponentArrayInterpreter(ϕunc);
110+
ϕunc::AbstractVector, int_unc = ComponentArrayInterpreter(ϕunc); cor_starts
110111
) where {T}
111112
ϕuncc = int_unc(CA.getdata(ϕunc))
112113
n_θP, n_θMs, (n_θM, n_batch) = length(ζP), length(ζMs), size(ζMs)
113114
# make sure not to create an UpperTriangular matrix of a CuArray in transformU_cholesky1
114-
UP = transformU_cholesky1(ϕuncc.ρsP)
115-
UM = transformU_cholesky1(ϕuncc.ρsM)
115+
UP = transformU_block_cholesky1(ϕuncc.ρsP, cor_starts.P)
116+
UM = transformU_block_cholesky1(ϕuncc.ρsM, cor_starts.M)
116117
cf = ϕuncc.coef_logσ2_logMs
117118
logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs)
118119
logσ2_logP = vec(CA.getdata(ϕuncc.logσ2_logP))

src/hybrid_case.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ For a specific case, provide functions that specify details
1111
- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
1212
optionally
1313
- `gen_hybridcase_synthetic`
14-
- `get_hybridcase_FloatType` (defaults to eltype(θM))
14+
- `get_hybridcase_FloatType` (defaults to `eltype(θM)`)
15+
- `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
1516
"""
1617
abstract type AbstractHybridCase end;
1718

@@ -93,7 +94,7 @@ function gen_hybridcase_synthetic end
9394
9495
Determine the FloatType for the given case and scenario; defaults to the element type of the `θM` parameter template.
9596
"""
96-
function get_hybridcase_FloatType(case::AbstractHybridCase; scenario)
97+
function get_hybridcase_FloatType(case::AbstractHybridCase; scenario=())
9798
return eltype(get_hybridcase_par_templates(case; scenario).θM)
9899
end
99100

@@ -114,5 +115,26 @@ function get_hybridcase_train_dataloader(case::AbstractHybridCase, rng::Abstract
114115
return(train_loader)
115116
end
116117

118+
"""
119+
get_hybridcase_cor_starts(case::AbstractHybridCase; scenario)
120+
121+
Specify blocks in correlation matrices among parameters.
122+
Returns a NamedTuple.
123+
- `P`: correlations among global parameters
124+
- `M`: correlations among ML-predicted parameters
125+
126+
Subsets of parameters that are correlated with each other but not correlated with
127+
parameters of other subranges are specified by indicating the starting position
128+
of each subrange.
129+
E.g. if within the global parameter vector `(p1, p2, p3)`, `p1` and `p2` are correlated,
130+
but parameter `p3` is not correlated with them,
131+
then the first subrange starts at position 1 and the second subrange starts at position 3.
132+
If there is only a single block of all ML-predicted parameters being correlated
133+
with each other then this block starts at position 1: `(P=(1,3), M=(1,))`.
134+
"""
135+
function get_hybridcase_cor_starts(case::AbstractHybridCase; scenario = ())
136+
(P=(1,), M=(1,))
137+
end
138+
117139

118140

test/test_HybridProblem.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ import Zygote
1212

1313
using OptimizationOptimisers
1414

15+
1516
const MLengine = Val(nameof(SimpleChains))
1617

1718
construct_problem = () -> begin
1819
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
1920
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
2021
transP = elementwise(exp)
2122
transM = Stacked(elementwise(identity), elementwise(exp))
23+
# Block starts of the correlation structure: each entry must be a tuple; `(1,)` is one block.
cov_starts = (P=(1,2),M=(1,)) # assume r0 independent of K2
2224
n_covar = 5
2325
n_batch = 10
2426
int_θdoubleMM = get_concrete(ComponentArrayInterpreter(
@@ -53,7 +55,7 @@ construct_problem = () -> begin
5355
# HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global,
5456
# g, ϕg, train_loader)
5557
HybridProblem(θP, θM, g_chain, f_doubleMM_with_global,
56-
transM, transP, n_covar, n_batch, train_loader)
58+
transM, transP, n_covar, n_batch, train_loader, cov_starts)
5759
end
5860
prob = construct_problem();
5961
scenario = (:default,)
@@ -93,3 +95,55 @@ scenario = (:default,)
9395
@test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)
9496
end
9597
end
98+
99+
() -> begin
100+
@testset "neg_elbo_transnorm_gf cpu" begin
101+
rng = StableRNG(111)
102+
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine);
103+
train_loader = get_hybridcase_train_dataloader(prob)
104+
(xM, xP, y_o) = first(train_loader)
105+
n_batch = size(y_o,2)
106+
f = get_hybridcase_PBmodel(prob)
107+
(θP0, θM0) = get_hybridcase_par_templates(prob)
108+
(; transP, transM) = get_hybridcase_transforms(prob)
109+
110+
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
111+
θP0, θM0, ϕg0, n_batch; transP, transM);
112+
ϕ_ini = ϕ
113+
114+
cost = neg_elbo_transnorm_gf(rng, g, f, ϕ_ini, y_o,
115+
xM, xP, transPMs_batch, map(get_concrete, interpreters);
116+
n_MC = 8, logσ2y)
117+
@test cost isa Float64
118+
gr = Zygote.gradient(
119+
ϕ -> neg_elbo_transnorm_gf(
120+
rng, g, f, ϕ, y_o[:, 1:n_batch],
121+
xM[:, 1:n_batch], xP[1:n_batch],
122+
transPMs_batch, interpreters; n_MC = 8, logσ2y),
123+
CA.getdata(ϕ_ini))
124+
@test gr[1] isa Vector
125+
end;
126+
127+
if CUDA.functional()
128+
@testset "neg_elbo_transnorm_gf gpu" begin
129+
ϕ = CuArray(CA.getdata(ϕ_ini))
130+
xMg_batch = CuArray(xM[:, 1:n_batch])
131+
xP_batch = xP[1:n_batch] # used in f which runs on CPU
132+
cost = neg_elbo_transnorm_gf(rng, g_flux, f, ϕ, y_o[:, 1:n_batch],
133+
xMg_batch, xP_batch,
134+
transPMs_batch, map(get_concrete, interpreters);
135+
n_MC = 8, logσ2y)
136+
@test cost isa Float64
137+
gr = Zygote.gradient(
138+
ϕ -> neg_elbo_transnorm_gf(
139+
rng, g_flux, f, ϕ, y_o[:, 1:n_batch],
140+
xMg_batch, xP_batch,
141+
transPMs_batch, interpreters; n_MC = 8, logσ2y),
142+
ϕ)
143+
@test gr[1] isa CuVector
144+
@test eltype(gr[1]) == FT
145+
end
146+
end
147+
end #if false
148+
149+

0 commit comments

Comments
 (0)