Skip to content

Commit 750ab58

Browse files
authored
Merge pull request #16 from EarthyScience/dev
implement HybridPointSolver
2 parents cc7a7b2 + 245f796 commit 750ab58

17 files changed

+475
-145
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
99
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
12+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1213
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1314
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
17+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1618
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1719
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1820
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -34,12 +36,14 @@ BlockDiagonals = "0.1.42"
3436
CUDA = "5.5.2"
3537
ChainRulesCore = "1.25"
3638
Combinatorics = "1.0.2"
39+
CommonSolve = "0.2.4"
3740
ComponentArrays = "0.15.19"
3841
Flux = "v0.15.2, 0.16"
3942
GPUArraysCore = "0.1, 0.2"
4043
LinearAlgebra = "1.10.0"
4144
Lux = "1.4.2"
4245
MLUtils = "0.4.5"
46+
Optimization = "3.19.3, 4"
4347
Random = "1.10.0"
4448
SimpleChains = "0.4"
4549
StatsBase = "0.34.4"

dev/doubleMM.jl

Lines changed: 168 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,82 +5,195 @@ using StableRNGs
55
using Random
66
using Statistics
77
using ComponentArrays: ComponentArrays as CA
8-
8+
using Optimization
9+
using OptimizationOptimisers # Adam
10+
using UnicodePlots
911
using SimpleChains
10-
import Flux
12+
using Flux
1113
using MLUtils
12-
import Zygote
13-
1414
using CUDA
15-
using OptimizationOptimisers
16-
using Bijectors
17-
using UnicodePlots
18-
19-
const prob = DoubleMM.DoubleMMCase()
20-
scenario = (:default,)
21-
rng = StableRNG(111)
22-
23-
par_templates = get_hybridproblem_par_templates(prob; scenario)
2415

25-
#n_covar = get_hybridproblem_n_covar(prob; scenario)
26-
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
16+
rng = StableRNG(114)
17+
scenario = NTuple{0, Symbol}()
18+
scenario = (:use_Flux,)
2719

20+
#------ setup synthetic data and training data loader
2821
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
29-
) = gen_hybridcase_synthetic(rng, prob; scenario);
30-
31-
n_covar = size(xM,1)
22+
) = gen_hybridcase_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
23+
xM_cpu = xM
24+
if :use_Flux scenario
25+
xM = CuArray(xM_cpu)
26+
end
27+
get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc);
28+
batchsize = n_batch, partial = false)
29+
σ_o = exp(first(y_unc)/2)
30+
31+
# assign the train_loader, otherwise it eatch time creates another version of synthetic data
32+
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
33+
34+
#------- pointwise hybrid model fit
35+
solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
36+
#solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
37+
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
38+
(; ϕ, resopt) = solve(prob0, solver; scenario,
39+
rng, callback = callback_loss(100), maxiters = 1200);
40+
# update the problem with optimized parameters
41+
prob0o = HVI.update(prob0; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP)
42+
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
43+
scatterplot(θMs_true[1,:], θMs[1,:])
44+
scatterplot(θMs_true[2,:], θMs[2,:])
45+
46+
# do a few steps without minibatching,
47+
# by providing the data rather than the DataLoader
48+
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
49+
(; ϕ, resopt) = solve(prob0o, solver1; scenario, rng,
50+
callback = callback_loss(20), maxiters = 600);
51+
prob1o = HVI.update(prob0o; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP);
52+
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
53+
scatterplot(θMs_true[1,:], θMs[1,:])
54+
scatterplot(θMs_true[2,:], θMs[2,:])
55+
prob1o.θP
56+
scatterplot(vec(y_true), vec(y_pred))
57+
58+
# still overestimating θMs
59+
60+
() -> begin # with more iterations?
61+
prob2 = prob1o
62+
(; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
63+
callback = callback_loss(20), maxiters = 600);
64+
prob2o = update(prob2; ϕg=ϕ.ϕg, θP=ϕ.θP)
65+
y_pred_global, y_pred, θMs = gf(prob2o, xM, xP);
66+
prob2o.θP
67+
end
3268

3369

34-
#----- fit g to θMs_true
35-
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
36-
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
70+
#----------- fit g to true θMs
71+
() -> begin
72+
# and fit gf starting from true parameters
73+
prob = prob0
74+
g, ϕg0_cpu = get_hybridproblem_MLapplicator(prob; scenario);
75+
ϕg0 = (:use_Flux scenario) ? CuArray(ϕg0_cpu) : ϕg0_cpu
76+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
77+
78+
function loss_g(ϕg, x, g, transM; gpu_handler = HVI.default_GPU_DataHandler)
79+
ζMs = g(x, ϕg) # predict the log of the parameters
80+
ζMs_cpu = gpu_handler(ζMs)
81+
θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each column
82+
loss = sum(abs2, θMs .- θMs_true)
83+
return loss, θMs
84+
end
85+
loss_g(ϕg0, xM, g, transM)
86+
87+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
88+
Optimization.AutoZygote())
89+
optprob = Optimization.OptimizationProblem(optf, ϕg0);
90+
res = Optimization.solve(optprob, Adam(0.015), callback = callback_loss(100), maxiters = 2000);
91+
92+
ϕg_opt1 = res.u;
93+
l1, θMs = loss_g(ϕg_opt1, xM, g, transM)
94+
#scatterplot(θMs_true[1,:], θMs[1,:])
95+
scatterplot(θMs_true[2,:], θMs[2,:]) # able to fit θMs[2,:]
96+
97+
prob3 = HVI.update(prob0, ϕg = Array(ϕg_opt1), θP = θP_true)
98+
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
99+
(; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
100+
callback = callback_loss(50), maxiters = 600);
101+
prob3o = HVI.update(prob3; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP)
102+
y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario);
103+
scatterplot(θMs_true[2,:], θMs[2,:])
104+
prob3o.θP
105+
scatterplot(vec(y_true), vec(y_pred))
106+
scatterplot(vec(y_true), vec(y_o))
107+
scatterplot(vec(y_pred), vec(y_o))
37108

38-
function loss_g(ϕg, x, g, transM)
39-
ζMs = g(x, ϕg) # predict the log of the parameters
40-
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
41-
loss = sum(abs2, θMs .- θMs_true)
42-
return loss, θMs
109+
() -> begin # optimized loss is indeed lower than with true parameters
110+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
111+
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
112+
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
113+
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc)[1]
114+
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc)[1]
115+
#
116+
loss_gf(vcat(prob2o.ϕg, prob2o.θP), xM, xP, y_o, y_unc)[1]
117+
end
43118
end
44-
loss_g(ϕg0, xM, g, transM)
119+
120+
#----------- Hybrid Variational inference: HVI
45121

46-
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
47-
Optimization.AutoZygote())
48-
optprob = Optimization.OptimizationProblem(optf, ϕg0);
49-
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
122+
using MLUtils
123+
import Zygote
124+
125+
using CUDA
126+
using Bijectors
50127

51-
ϕg_opt1 = res.u;
52-
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
53-
scatterplot(vec(θMs_true), vec(θMs_pred))
128+
solver = HybridPosteriorSolver(; alg = Adam(0.01), n_batch = 60, n_MC = 3)
129+
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
130+
(; ϕ, θP, resopt) = solve(prob0o, solver; scenario,
131+
rng, callback = callback_loss(100), maxiters = 800);
132+
# update the problem with optimized parameters
133+
prob1o = HVI.update(prob0o; ϕg=cpu_ca(ϕ).ϕg, θP=θP)
134+
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
135+
scatterplot(θMs_true[1,:], θMs[1,:])
136+
scatterplot(θMs_true[2,:], θMs[2,:])
137+
hcat(θP_true, θP) # all parameters overestimated
54138

55-
f = get_hybridproblem_PBmodel(prob; scenario)
56-
py = get_hybridproblem_neg_logden_obs(prob; scenario)
57139

58-
#----------- fit g and θP to y_o
59140
() -> begin
60-
# end2end inversion
141+
#n_covar = get_hybridproblem_n_covar(prob; scenario)
142+
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
61143

62-
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
63-
ϕg = 1:length(ϕg0), θP = par_templates.θP))
64-
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
144+
n_covar = size(xM, 1)
65145

66-
# Pass the site-data for the batches as separate vectors wrapped in a tuple
67-
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
146+
#----- fit g to θMs_true
147+
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
148+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
68149

69-
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
70-
l1 = loss_gf(p0, train_loader.data...)[1]
150+
function loss_g(ϕg, x, g, transM)
151+
ζMs = g(x, ϕg) # predict the log of the parameters
152+
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
153+
loss = sum(abs2, θMs .- θMs_true)
154+
return loss, θMs
155+
end
156+
loss_g(ϕg0, xM, g, transM)
71157

72-
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
158+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
73159
Optimization.AutoZygote())
74-
optprob = OptimizationProblem(optf, p0, train_loader)
160+
optprob = Optimization.OptimizationProblem(optf, ϕg0);
161+
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
75162

76-
res = Optimization.solve(
77-
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
163+
ϕg_opt1 = res.u;
164+
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
165+
scatterplot(vec(θMs_true), vec(θMs_pred))
78166

79-
l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
80-
scatterplot(vec(θMs_true), vec(θMs))
81-
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
82-
scatterplot(vec(y_pred), vec(y_o))
83-
hcat(par_templates.θP, int_ϕθP(res.u).θP)
167+
f = get_hybridproblem_PBmodel(prob; scenario)
168+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
169+
170+
#----------- fit g and θP to y_o
171+
() -> begin
172+
# end2end inversion
173+
174+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
175+
ϕg = 1:length(ϕg0), θP = par_templates.θP))
176+
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
177+
178+
# Pass the site-data for the batches as separate vectors wrapped in a tuple
179+
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
180+
181+
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
182+
l1 = loss_gf(p0, train_loader.data...)[1]
183+
184+
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
185+
Optimization.AutoZygote())
186+
optprob = OptimizationProblem(optf, p0, train_loader)
187+
188+
res = Optimization.solve(
189+
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000)
190+
191+
l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
192+
scatterplot(vec(θMs_true), vec(θMs))
193+
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
194+
scatterplot(vec(y_pred), vec(y_o))
195+
hcat(par_templates.θP, int_ϕθP(res.u).θP)
196+
end
84197
end
85198

86199
#---------- HVI
@@ -92,8 +205,6 @@ FT = get_hybridproblem_float_type(prob; scenario)
92205
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
93206
ϕ_true = ϕ
94207

95-
96-
97208
() -> begin
98209
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
99210
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -245,7 +356,7 @@ y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
245356
size(y_pred) # n_obs x n_site, n_sample_pred
246357

247358
σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
248-
σ_o = exp.(y_unc[:,1] / 2)
359+
σ_o = exp.(y_unc[:, 1] / 2)
249360

250361
#describe(σ_o_post)
251362
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ function HVI.construct_3layer_MLApplicator(
5555
construct_ChainsApplicator(rng, g_chain, float_type)
5656
end
5757

58+
function HVI.cpu_ca(ca::CA.ComponentArray)
59+
CA.ComponentArray(cpu(CA.getdata(ca)), CA.getaxes(ca))
60+
end
61+
62+
5863

5964

6065
end # module

0 commit comments

Comments
 (0)