Skip to content

Commit 4d4da56

Browse files
committed
encapsulate generating transformations and interpreters
implement predict_gf function
1 parent 4a1d5c2 commit 4d4da56

File tree

6 files changed

+277
-173
lines changed

6 files changed

+277
-173
lines changed

dev/doubleMM.jl

Lines changed: 91 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
5353

5454
f = gen_hybridcase_PBmodel(case; scenario)
5555

56+
#----------- fit g and θP to y_o
5657
() -> begin
57-
#----------- fit g and θP to y_o
5858
# end2end inversion
5959

6060
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
@@ -82,67 +82,78 @@ f = gen_hybridcase_PBmodel(case; scenario)
8282
end
8383

8484
#---------- HVI
85-
# TODO think about good general initializations
86-
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
87-
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
88-
mean_σ_o_MC = 0.006042
89-
90-
# correlation matrices
91-
ρsP = zeros(sum(1:(n_θP - 1)))
92-
ρsM = zeros(sum(1:(n_θM - 1)))
93-
94-
ϕunc = CA.ComponentVector(;
95-
logσ2_logP = logσ2_logP,
96-
coef_logσ2_logMs = coef_logσ2_logMs,
97-
ρsP,
98-
ρsM)
99-
int_unc = ComponentArrayInterpreter(ϕunc)
100-
101-
# for a conservative uncertainty assume σ2=1e-10 and no relationship with magnitude
102-
ϕunc0 = CA.ComponentVector(;
103-
logσ2_logP = fill(-10.0, n_θP),
104-
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
105-
ρsP,
106-
ρsM)
107-
10885
logσ2y = 2 .* log.(σ_o)
10986
n_MC = 3
87+
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
88+
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
89+
ϕ_true = ϕ
11090

111-
transPMs_batch = as(
112-
(P = as(Array, asℝ₊, n_θP),
113-
Ms = as(Array, asℝ₊, n_θM, n_batch)))
114-
transPMs_all = as(
115-
(P = as(Array, asℝ₊, n_θP),
116-
Ms = as(Array, asℝ₊, n_θM, n_site)))
117-
118-
n_ϕg = length(ϕg_opt1)
119-
ϕt_true = θ = CA.ComponentVector(;
120-
μP = θP_true,
121-
ϕg = ϕg_opt1,
122-
unc = ϕunc);
123-
trans_gu = as(
124-
(μP = as(Array, asℝ₊, n_θP),
125-
ϕg = as(Array, n_ϕg),
126-
unc = as(Array, length(ϕunc))))
127-
trans_g = as(
128-
(μP = as(Array, asℝ₊, n_θP),
129-
ϕg = as(Array, n_ϕg)))
130-
131-
#const
132-
int_PMs_batch = ComponentArrayInterpreter(CA.ComponentVector(; θP = θP_true,
133-
θMs = CA.ComponentMatrix(
134-
zeros(n_θM, n_batch), first(CA.getaxes(θMs_true)), CA.Axis(i = 1:n_batch))))
135-
136-
interpreters = interpreters_g = map(get_concrete,
137-
(;
138-
μP_ϕg_unc = ComponentArrayInterpreter(ϕt_true),
139-
PMs = int_PMs_batch,
140-
unc = ComponentArrayInterpreter(ϕunc)
141-
))
142-
143-
ϕ_true = inverse_ca(trans_gu, ϕt_true)
91+
() -> begin
92+
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
93+
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
94+
mean_σ_o_MC = 0.006042
95+
96+
# correlation matrices
97+
ρsP = zeros(sum(1:(n_θP - 1)))
98+
ρsM = zeros(sum(1:(n_θM - 1)))
99+
100+
ϕunc = CA.ComponentVector(;
101+
logσ2_logP = logσ2_logP,
102+
coef_logσ2_logMs = coef_logσ2_logMs,
103+
ρsP,
104+
ρsM)
105+
int_unc = ComponentArrayInterpreter(ϕunc)
106+
107+
# for a conservative uncertainty assume σ2=1e-10 and no relationship with magnitude
108+
ϕunc0 = CA.ComponentVector(;
109+
logσ2_logP = fill(-10.0, n_θP),
110+
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
111+
ρsP,
112+
ρsM)
113+
114+
transPMs_batch = as(
115+
(P = as(Array, asℝ₊, n_θP),
116+
Ms = as(Array, asℝ₊, n_θM, n_batch)))
117+
transPMs_allsites = as(
118+
(P = as(Array, asℝ₊, n_θP),
119+
Ms = as(Array, asℝ₊, n_θM, n_site)))
120+
121+
n_ϕg = length(ϕg_opt1)
122+
ϕt_true = θ = CA.ComponentVector(;
123+
μP = θP_true,
124+
ϕg = ϕg_opt1,
125+
unc = ϕunc)
126+
trans_gu = as(
127+
(μP = as(Array, asℝ₊, n_θP),
128+
ϕg = as(Array, n_ϕg),
129+
unc = as(Array, length(ϕunc))))
130+
trans_g = as(
131+
(μP = as(Array, asℝ₊, n_θP),
132+
ϕg = as(Array, n_ϕg)))
133+
134+
#const
135+
int_PMs_batch = ComponentArrayInterpreter(CA.ComponentVector(; θP = θP_true,
136+
θMs = CA.ComponentMatrix(
137+
zeros(n_θM, n_batch), first(CA.getaxes(θMs_true)), CA.Axis(i = 1:n_batch))))
138+
139+
interpreters = interpreters_g = map(get_concrete,
140+
(;
141+
μP_ϕg_unc = ComponentArrayInterpreter(ϕt_true),
142+
PMs = int_PMs_batch,
143+
unc = ComponentArrayInterpreter(ϕunc)
144+
))
145+
146+
ϕ_true = inverse_ca(trans_gu, ϕt_true)
147+
end
148+
149+
ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch
150+
#
151+
# true values
144152
ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance
145-
ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕunc0); # scratch
153+
# hardcoded from HMC inversion
154+
ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
155+
ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
156+
mean_σ_o_MC = 0.006042
146157

147158
# test cost function and gradient
148159
() -> begin
@@ -161,10 +172,10 @@ end
161172
train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch)
162173

163174
optf = Optimization.OptimizationFunction(
164-
(ζg, data) -> begin
175+
(ϕ, data) -> begin
165176
xM, y_o = data
166177
neg_elbo_transnorm_gf(
167-
rng, g, f, ζg, y_o, xM, transPMs_batch,
178+
rng, g, f, ϕ, y_o, xM, transPMs_batch,
168179
map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
169180
end,
170181
Optimization.AutoZygote())
@@ -181,7 +192,7 @@ g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario
181192

182193
# optimize using LUX
183194
() -> begin
184-
using Lux
195+
#using Lux
185196
g_lux = Lux.Chain(
186197
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
187198
Lux.Dense(n_covar => n_covar * 4, tanh),
@@ -208,18 +219,19 @@ function fcost(ϕ)
208219
n_MC = 8, logσ2y = logσ2y)
209220
end
210221
fcost(ϕ)
211-
Zygote.gradient(fcost, ϕ) |> cpu;
222+
#Zygote.gradient(fcost, ϕ) |> cpu;
212223
gr = Zygote.gradient(fcost, CA.getdata(ϕ));
213-
gr_c = CA.ComponentArray(gr[1], CA.getaxes(ϕ)...)
224+
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
214225

215226
train_loader = MLUtils.DataLoader((xM_gpu, y_o), batchsize = n_batch)
216227

217228
optf = Optimization.OptimizationFunction(
218-
(ζg, data) -> begin
229+
(ϕ, data) -> begin
219230
xM, y_o = data
220-
neg_elbo_transnorm_gf(
221-
rng, g_flux, f, ζg, y_o, xM, transPMs_batch,
222-
map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
231+
fcost(ϕ)
232+
# neg_elbo_transnorm_gf(
233+
# rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
234+
# map(get_concrete, interpreters); n_MC = 5, logσ2y)
223235
end,
224236
Optimization.AutoZygote())
225237
optprob = Optimization.OptimizationProblem(
@@ -230,40 +242,31 @@ res = res_gpu = Optimization.solve(
230242
# start from zero
231243
() -> begin
232244
optprob = Optimization.OptimizationProblem(
233-
optf, CA.getdata(ϕ_ini0) |> Flux.gpu, train_loader);
245+
optf, CA.getdata(ϕ_ini0) |> Flux.gpu, train_loader)
234246
res = res_gpu = Optimization.solve(
235-
optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 4_000);
247+
optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 4_000)
236248
end
237249

238-
ζ_VIc = interpreters_g.μP_ϕg_unc(res.u |> Flux.cpu)
239-
ζMs_VI = g(xM, ζ_VIc.ϕg)
240-
ϕunc_VI = int_unc(ζ_VIc.unc)
250+
ζ_VIc = interpreters.μP_ϕg_unc(res.u |> Flux.cpu)
251+
ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
252+
ϕunc_VI = interpreters.unc(ζ_VIc.unc)
241253

242254
hcat(θP_true, exp.(ζ_VIc.μP))
243255
plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
244256
#lineplot!(plt, 0.0, 1.1, identity)
245257
#
246-
hcat(ϕunc, ϕunc_VI) # need to compare to MC sample
258+
hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
247259
# hard to estimate for original very small theta's but otherwise good
248260

249261
# test predicting correct obs-uncertainty of predictive posterior
250262
# TODO reuse g_flux rather than g
251263
n_sample_pred = 200
252-
intm_PMs_gen = ComponentArrayInterpreter(CA.ComponentVector(; θP = θP_true,
253-
θMs = CA.ComponentMatrix(
254-
zeros(n_θM, n_site), first(CA.getaxes(θMs_true)), CA.Axis(i = 1:n_sample_pred))))
255-
256-
ζs, _ = HVI.generate_ζ(rng, g, f, res.u |> Flux.cpu, xM,
257-
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred)
258-
# ζ = ζs[:,1]
259-
θsc = stack(
260-
ζ -> CA.getdata(CA.ComponentVector(
261-
TransformVariables.transform(transPMs_all, ζ))),
262-
eachcol(ζs));
263-
y_pred = stack(map(ζ -> first(HVI.predict_y(ζ, f, transPMs_all)), eachcol(ζs)));
264-
265-
size(y_pred)
266-
σ_o_post = mapslices(std, y_pred; dims = 3)[:, :, 1];
264+
y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
265+
get_transPMs, get_ca_int_PMs, n_sample_pred);
266+
size(y_pred) # n_obs x n_site, n_sample_pred
267+
268+
σ_o_post = dropdims(std(y_pred; dims = 3), dims=3)
269+
267270
#describe(σ_o_post)
268271
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
269272
mean(σ_o_post, dims = 2), sqrt.(mean(abs2, σ_o_post, dims = 2)))

src/HybridVariationalInference.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ include("logden_normal.jl")
4444
#export - all internal
4545
include("cholesky.jl")
4646

47-
export neg_elbo_transnorm_gf
47+
export neg_elbo_transnorm_gf, predict_gf
4848
include("elbo.jl")
4949

50+
export init_hybrid_params
51+
include("init_hybrid_params.jl")
52+
5053
export DoubleMM
5154
include("DoubleMM/DoubleMM.jl")
5255

src/elbo.jl

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,37 @@ expected value of the likelihood of observations.
2121
function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractMatrix,
2222
transPMs, interpreters::NamedTuple;
2323
n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler())
24-
ζ, logdetΣ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
25-
ζ_cpu = gpu_data_handler(ζ) # differentiable fetch to CPU in Flux package extension
26-
#ζi = first(eachcol(ζ_cpu))
27-
nLy = reduce(+, map(eachcol(ζ_cpu)) do ζi
24+
ζs, logdetΣ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
25+
ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
26+
#ζi = first(eachcol(ζs_cpu))
27+
nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
2828
y_pred_i, logjac = predict_y(ζi, f, transPMs)
2929
nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
3030
nLy1 - logjac
3131
end) / n_MC
32-
ent = entropy_MvNormal(size(ζ, 1), logdetΣ) # defined in logden_normal
33-
nLy - ent
34-
end
35-
36-
function predict_gf(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix,
37-
transPMs, interpreters::NamedTuple;
38-
n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler())
39-
ζ, logdetΣ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
40-
ζ_cpu = gpu_data_handler(ζ) # differentiable fetch to CPU in Flux package extension
41-
#ζi = first(eachcol(ζ_cpu))
42-
nLy = reduce(+, map(eachcol(ζ_cpu)) do ζi
43-
y_pred_i, logjac = predict_y(ζi, f, transPMs)
44-
nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
45-
nLy1 - logjac
46-
end) / n_MC
47-
ent = entropy_MvNormal(size(ζ, 1), logdetΣ) # defined in logden_normal
32+
ent = entropy_MvNormal(size(ζs, 1), logdetΣ) # defined in logden_normal
4833
nLy - ent
4934
end
5035

36+
"""
37+
predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters;
38+
get_transPMs, get_ca_int_PMs, n_sample_pred=200,
39+
gpu_data_handler=get_default_GPUHandler())
5140
41+
Prediction function for hybrid model. Returns an Array `(n_obs, n_site, n_sample_pred)`.
42+
"""
43+
function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters;
44+
get_transPMs, get_ca_int_PMs, n_sample_pred=200,
45+
gpu_data_handler=get_default_GPUHandler())
46+
n_site = size(xM, 2)
47+
intm_PMs_gen = get_ca_int_PMs(n_site)
48+
tans_PMs_gen = get_transPMs(n_site)
49+
ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM),
50+
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred)
51+
ζs_cpu = gpu_data_handler(ζs) #
52+
y_pred = stack(map(ζ -> first(predict_y(ζ, f, tans_PMs_gen)), eachcol(ζs_cpu)));
53+
y_pred
54+
end
5255

5356
"""
5457
Generate samples of (inv-transformed) model parameters, ζ, and Log-Determinant
@@ -144,7 +147,7 @@ function _create_random(rng, ::GPUArraysCore.AbstractGPUVector{T}, dims...) wher
144147
# ignores rng
145148
# https://discourse.julialang.org/t/help-using-cuda-zygote-and-random-numbers/123458/4?u=bgctw
146149
# Zygote.@ignore CUDA.randn(rng, dims...)
147-
Zygote.@ignore CUDA.randn(dims...)
150+
ChainRulesCore.@ignore_derivatives CUDA.randn(dims...)
148151
end
149152

150153
"""

src/init_hybrid_params.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ)
3+
4+
Setup ComponentVector of parameters to optimize, and associated tools.
5+
Returns a NamedTuple of
6+
- ϕ: A ComponentVector of parameters to optimize
7+
- transPMs_batch, interpreters: Transformations and interpreters as
8+
required by `neg_elbo_transnorm_gf`.
9+
- get_transPMs: a function returning transformations `(n_site) -> (;P,Ms)`
10+
- get_ca_int_PMs: a function returning ComponentArrayInterpreter for PMs vector
11+
with PMs shaped as a matrix of `n_site` columns of `θM`
12+
13+
# Arguments
14+
- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
15+
- `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator`
16+
- `n_batch`: the number of sites to be predicted in each mini-batch
17+
- `transP`, `transM`: the Transformations for the global and site-dependent parameters
18+
"""
19+
function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ)
20+
n_θP = length(θP)
21+
n_θM = length(θM)
22+
n_ϕg = length(ϕg)
23+
# zero correlation matrices
24+
ρsP = zeros(sum(1:(n_θP - 1)))
25+
ρsM = zeros(sum(1:(n_θM - 1)))
26+
ϕunc0 = CA.ComponentVector(;
27+
logσ2_logP = fill(-10.0, n_θP),
28+
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
29+
ρsP,
30+
ρsM)
31+
ϕt = CA.ComponentVector(;
32+
μP = θP,
33+
ϕg = ϕg,
34+
unc = ϕunc0);
35+
#
36+
get_transPMs = let transP=transP, transM=transM, n_θP=n_θP, n_θM=n_θM
37+
function get_transPMs_inner(n_site)
38+
transPMs = as(
39+
(P = as(Array, transP, n_θP),
40+
Ms = as(Array, transM, n_θM, n_site)))
41+
end
42+
end
43+
transPMs_batch = get_transPMs(n_batch)
44+
trans_gu = as(
45+
(μP = as(Array, asℝ₊, n_θP),
46+
ϕg = as(Array, n_ϕg),
47+
unc = as(Array, length(ϕunc0))))
48+
ϕ = inverse_ca(trans_gu, ϕt)
49+
# trans_g = as(
50+
# (μP = as(Array, asℝ₊, n_θP),
51+
# ϕg = as(Array, n_ϕg)))
52+
#
53+
get_ca_int_PMs = let
54+
function get_ca_int_PMs_inner(n_site)
55+
ComponentArrayInterpreter(CA.ComponentVector(; θP,
56+
θMs = CA.ComponentMatrix(
57+
zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site))))
58+
end
59+
60+
end
61+
interpreters = map(get_concrete,
62+
(;
63+
μP_ϕg_unc = ComponentArrayInterpreter(ϕt),
64+
PMs = get_ca_int_PMs(n_batch),
65+
unc = ComponentArrayInterpreter(ϕunc0)
66+
))
67+
(;ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs)
68+
end
69+

0 commit comments

Comments
 (0)