Skip to content

Commit 245f796

Browse files
committed
implement HybridPosteriorSolver
1 parent 6da5b81 commit 245f796

File tree

8 files changed

+162
-92
lines changed

8 files changed

+162
-92
lines changed

dev/doubleMM.jl

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ xM_cpu = xM
2424
if :use_Flux scenario
2525
xM = CuArray(xM_cpu)
2626
end
27-
get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
27+
get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc);
28+
batchsize = n_batch, partial = false)
2829
σ_o = exp(first(y_unc)/2)
2930

3031
# assign the train_loader, otherwise each call creates another version of the synthetic data
3132
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
3233

3334
#------- pointwise hybrid model fit
34-
#solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
35-
solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
35+
solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
36+
#solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
3637
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
3738
(; ϕ, resopt) = solve(prob0, solver; scenario,
3839
rng, callback = callback_loss(100), maxiters = 1200);
@@ -116,70 +117,83 @@ end
116117
end
117118
end
118119

119-
#----------- Hybrid Variational inference
120+
#----------- Hybrid Variational inference: HVI
120121

121122
using MLUtils
122123
import Zygote
123124

124125
using CUDA
125126
using Bijectors
126127

128+
solver = HybridPosteriorSolver(; alg = Adam(0.01), n_batch = 60, n_MC = 3)
129+
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
130+
(; ϕ, θP, resopt) = solve(prob0o, solver; scenario,
131+
rng, callback = callback_loss(100), maxiters = 800);
132+
# update the problem with optimized parameters
133+
prob1o = HVI.update(prob0o; ϕg=cpu_ca(ϕ).ϕg, θP=θP)
134+
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
135+
scatterplot(θMs_true[1,:], θMs[1,:])
136+
scatterplot(θMs_true[2,:], θMs[2,:])
137+
hcat(θP_true, θP) # all parameters overestimated
127138

128-
#n_covar = get_hybridproblem_n_covar(prob; scenario)
129-
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
130139

131-
n_covar = size(xM, 1)
140+
() -> begin
141+
#n_covar = get_hybridproblem_n_covar(prob; scenario)
142+
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
132143

133-
#----- fit g to θMs_true
134-
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
135-
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
144+
n_covar = size(xM, 1)
136145

137-
function loss_g(ϕg, x, g, transM)
138-
ζMs = g(x, ϕg) # predict the log of the parameters
139-
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
140-
loss = sum(abs2, θMs .- θMs_true)
141-
return loss, θMs
142-
end
143-
loss_g(ϕg0, xM, g, transM)
146+
#----- fit g to θMs_true
147+
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
148+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
144149

145-
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
146-
Optimization.AutoZygote())
147-
optprob = Optimization.OptimizationProblem(optf, ϕg0);
148-
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
150+
function loss_g(ϕg, x, g, transM)
151+
ζMs = g(x, ϕg) # predict the log of the parameters
152+
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
153+
loss = sum(abs2, θMs .- θMs_true)
154+
return loss, θMs
155+
end
156+
loss_g(ϕg0, xM, g, transM)
157+
158+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
159+
Optimization.AutoZygote())
160+
optprob = Optimization.OptimizationProblem(optf, ϕg0);
161+
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
149162

150-
ϕg_opt1 = res.u;
151-
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
152-
scatterplot(vec(θMs_true), vec(θMs_pred))
163+
ϕg_opt1 = res.u;
164+
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
165+
scatterplot(vec(θMs_true), vec(θMs_pred))
153166

154-
f = get_hybridproblem_PBmodel(prob; scenario)
155-
py = get_hybridproblem_neg_logden_obs(prob; scenario)
167+
f = get_hybridproblem_PBmodel(prob; scenario)
168+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
156169

157-
#----------- fit g and θP to y_o
158-
() -> begin
159-
# end2end inversion
170+
#----------- fit g and θP to y_o
171+
() -> begin
172+
# end2end inversion
160173

161-
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
162-
ϕg = 1:length(ϕg0), θP = par_templates.θP))
163-
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
174+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
175+
ϕg = 1:length(ϕg0), θP = par_templates.θP))
176+
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
164177

165-
# Pass the site-data for the batches as separate vectors wrapped in a tuple
166-
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
178+
# Pass the site-data for the batches as separate vectors wrapped in a tuple
179+
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
167180

168-
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
169-
l1 = loss_gf(p0, train_loader.data...)[1]
181+
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
182+
l1 = loss_gf(p0, train_loader.data...)[1]
170183

171-
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
172-
Optimization.AutoZygote())
173-
optprob = OptimizationProblem(optf, p0, train_loader)
184+
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
185+
Optimization.AutoZygote())
186+
optprob = OptimizationProblem(optf, p0, train_loader)
174187

175-
res = Optimization.solve(
176-
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
188+
res = Optimization.solve(
189+
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
177190

178-
l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
179-
scatterplot(vec(θMs_true), vec(θMs))
180-
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
181-
scatterplot(vec(y_pred), vec(y_o))
182-
hcat(par_templates.θP, int_ϕθP(res.u).θP)
191+
l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
192+
scatterplot(vec(θMs_true), vec(θMs))
193+
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
194+
scatterplot(vec(y_pred), vec(y_o))
195+
hcat(par_templates.θP, int_ϕθP(res.u).θP)
196+
end
183197
end
184198

185199
#---------- HVI

src/AbstractHybridProblem.jl

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ optionally
1717
"""
1818
abstract type AbstractHybridProblem end;
1919

20-
2120
"""
2221
get_hybridproblem_MLapplicator([rng::AbstractRNG,] ::AbstractHybridProblem; scenario=())
2322
@@ -28,9 +27,9 @@ returns a Tuple of
2827
- AbstractModelApplicator
2928
- initial parameter vector
3029
"""
31-
function get_hybridproblem_MLapplicator end
30+
function get_hybridproblem_MLapplicator end
3231

33-
function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario=())
32+
function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario = ())
3433
get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario)
3534
end
3635

@@ -56,16 +55,14 @@ function get_hybridproblem_PBmodel end
5655
Provide a `function(y_obs, ypred) -> Real` that computes the negative logdensity
5756
of the observations, given the predictions.
5857
"""
59-
function get_hybridproblem_neg_logden_obs end
60-
58+
function get_hybridproblem_neg_logden_obs end
6159

6260
"""
6361
get_hybridproblem_par_templates(::AbstractHybridProblem; scenario)
6462
6563
Provide tuple of templates of ComponentVectors `θP` and `θM`.
6664
"""
67-
function get_hybridproblem_par_templates end
68-
65+
function get_hybridproblem_par_templates end
6966

7067
"""
7168
get_hybridproblem_transforms(::AbstractHybridProblem; scenario)
@@ -96,7 +93,7 @@ function get_hybridproblem_n_covar(prob::AbstractHybridProblem; scenario)
9693
train_loader = get_hybridproblem_train_dataloader(Random.default_rng(), prob; scenario)
9794
(xM, xP, y_o, y_unc) = first(train_loader)
9895
n_covar = size(xM, 1)
99-
return(n_covar)
96+
return (n_covar)
10097
end
10198

10299
"""
@@ -118,7 +115,7 @@ function gen_hybridcase_synthetic end
118115
119116
Determine the FloatType for given Case and scenario, defaults to Float32
120117
"""
121-
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario=())
118+
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario = ())
122119
return eltype(get_hybridproblem_par_templates(prob; scenario).θM)
123120
end
124121

@@ -131,20 +128,20 @@ Return a DataLoader that provides a tuple of
131128
- `y_o`: matrix of observations with added noise, with one column per site
132129
- `y_unc`: matrix `sizeof(y_o)` of uncertainty information
133130
"""
134-
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::AbstractHybridProblem;
135-
scenario = (), n_batch = 10)
131+
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::AbstractHybridProblem;
132+
scenario = (), n_batch = 10)
136133
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, prob; scenario)
137134
xM_gpu = :use_Flux scenario ? CuArray(xM) : xM
138-
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
139-
return(train_loader)
135+
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc);
136+
batchsize = n_batch, partial = false)
137+
return (train_loader)
140138
end
141139

142140
function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ())
143141
rng::AbstractRNG = Random.default_rng()
144142
get_hybridproblem_train_dataloader(rng, prob; scenario)
145143
end
146144

147-
148145
"""
149146
get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario)
150147
@@ -163,8 +160,5 @@ If there is only single block of all ML-predicted parameters being correlated
163160
with each other then this block starts at position 1: `(P=(1,3), M=(1,))`.
164161
"""
165162
function get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario = ())
166-
(P=(1,), M=(1,))
163+
(P = (1,), M = (1,))
167164
end
168-
169-
170-

src/HybridSolver.jl

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
2424
f = get_hybridproblem_PBmodel(prob; scenario)
2525
y_global_o = FT[] # TODO
2626
loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP)
27+
# call loss function once
28+
l1 = loss_gf(p0, first(train_loader)...)[1]
2729
# data1 = first(train_loader)
28-
# l1 = loss_gf(p0, first(train_loader)...)[1]
2930
# Zygote.gradient(p0 -> loss_gf(p0, data1...)[1], p0)
3031
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
3132
Optimization.AutoZygote())
32-
optprob = OptimizationProblem(optf, p0, train_loader)
33+
optprob = OptimizationProblem(optf, CA.getdata(p0), train_loader)
3334
res = Optimization.solve(optprob, solver.alg; kwargs...)
3435
(;ϕ = int_ϕθP(res.u), resopt = res)
3536
end
@@ -42,24 +43,54 @@ struct HybridPosteriorSolver{A} <: AbstractHybridSolver
4243
n_MC::Int
4344

4445
end
45-
HybridPosteriorSolver(; alg, n_batch = 10, n_MC = 3) = HybridPointSolver(alg, n_batch, n_MC)
46+
HybridPosteriorSolver(; alg, n_batch = 10, n_MC = 3) = HybridPosteriorSolver(alg, n_batch, n_MC)
4647

4748
function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorSolver;
4849
scenario, rng = Random.default_rng(), kwargs...)
4950
par_templates = get_hybridproblem_par_templates(prob; scenario)
51+
(; θP, θM) = par_templates
5052
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
5153
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
5254
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
53-
θP_true, θMs_true[:, 1], ϕg0, solver.n_batch; transP, transM);
55+
θP, θM, ϕg0, solver.n_batch; transP, transM);
5456
use_gpu = (:use_Flux scenario)
55-
# ϕd = use_gpu ? CuArray(ϕ) : ϕ
56-
# train_loader = get_hybridproblem_train_dataloader(rng, prob; scenario, solver.n_batch)
57-
# f = get_hybridproblem_PBmodel(prob; scenario)
58-
# y_global_o = Float32[] # TODO
59-
# loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP)
60-
# optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
61-
# Optimization.AutoZygote())
62-
# optprob = OptimizationProblem(optf, p0, train_loader)
63-
# res = Optimization.solve(optprob, solver.alg; kwargs...)
57+
ϕ0 = use_gpu ? CuArray(ϕ) : ϕ # TODO replace CuArray by something more general
58+
train_loader = get_hybridproblem_train_dataloader(rng, prob; scenario, solver.n_batch)
59+
f = get_hybridproblem_PBmodel(prob; scenario)
60+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
61+
y_global_o = Float32[] # TODO
62+
loss_elbo = get_loss_elbo(g, transPMs_batch, f, py, y_global_o, interpreters; solver.n_MC)
63+
# test loss function once
64+
l0 = loss_elbo(ϕ0, rng, first(train_loader)...)
65+
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_elbo(ϕ, rng, data...)[1],
66+
Optimization.AutoZygote())
67+
optprob = OptimizationProblem(optf, CA.getdata(ϕ0), train_loader)
68+
res = Optimization.solve(optprob, solver.alg; kwargs...)
69+
ϕc = interpreters.μP_ϕg_unc(res.u)
70+
(;ϕ = ϕc, θP = cpu_ca(apply_preserve_axes(transP,ϕc.μP)), resopt = res)
71+
end
72+
73+
"""
74+
Create a loss function for parameter vector ϕ, given
75+
- g(x, ϕ): machine learning model
76+
- transPMs: transformation from unconstrained space to parameter space
77+
- f(θMs, θP): mechanistic model
78+
- interpreters: assigning structure to pure vectors, see neg_elbo_transnorm_gf
79+
- n_MC: number of Monte-Carlo samples to approximate the expected value across the distribution
80+
81+
The loss function takes in addition to ϕ, data that changes with minibatch
82+
- rng: random generator
83+
- xM: matrix of covariates, sites in columns
84+
- xP: drivers for the processmodel: Iterator of size n_site
85+
- y_o, y_unc: matrix of observations and uncertainties, sites in columns
86+
"""
87+
function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters; n_MC)
88+
let g = g, transPMs = transPMs, f = f, py=py, y_o_global = y_o_global, n_MC = n_MC
89+
interpreters = map(get_concrete, interpreters)
90+
function loss_elbo(ϕ, rng, xM, xP, y_o, y_unc)
91+
neg_elbo_transnorm_gf(rng, ϕ, g, transPMs, f, py,
92+
xM, xP, y_o, y_unc, interpreters; n_MC)
93+
end
94+
end
6495
end
6596

src/elbo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ function generate_ζ(rng, g, ϕ::AbstractVector, xM::AbstractMatrix,
8787
μ_ζMs0 = g(xM, ϕg) # TODO provide μ_ζP to g
8888
ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_starts)
8989
#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
90+
# @show size(ζ_resid)
91+
# @show length(interpreters.PMs)
9092
ζ = stack(map(eachcol(ζ_resid)) do r
9193
rc = interpreters.PMs(r)
9294
ζP = μ_ζP .+ rc.P

src/gf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Create a loss function for parameter vector p, given
5252
- int_ϕθP: interpreter attaching an axis with components ϕg and pc.θP
5353
"""
5454
function get_loss_gf(g, transM, f, y_o_global, int_ϕθP::AbstractComponentArrayInterpreter)
55-
let g = g, transM = transM, f = f, int_ϕθP = int_ϕθP
55+
let g = g, transM = transM, f = f, int_ϕθP = int_ϕθP, y_o_global = y_o_global
5656
function loss_gf(p, xM, xP, y_o, y_unc)
5757
σ = exp.(y_unc ./ 2)
5858
pc = int_ϕθP(p)

src/init_hybrid_params.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function init_hybrid_params(θP, θM, ϕg, n_batch;
3737
ρsP,
3838
ρsM)
3939
ϕ = CA.ComponentVector(;
40-
μP = inverse(transP)(θP),
40+
μP = apply_preserve_axes(inverse(transP),θP),
4141
ϕg = ϕg,
4242
unc = ϕunc0);
4343
#

src/util_ca.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ Move ComponentArray from gpu to cpu.
66
function cpu_ca end
77
# define in FluxExt
88

9+
function apply_preserve_axes(f, ca::CA.ComponentArray)
10+
CA.ComponentArray(f(ca), CA.getaxes(ca))
11+
end
12+
913

0 commit comments

Comments
 (0)