Skip to content

Commit 84743f6

Browse files
committed
implement HybridPointSolver on cpu
1 parent cc7a7b2 commit 84743f6

File tree

10 files changed

+279
-64
lines changed

10 files changed

+279
-64
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
99
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
12+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1213
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1314
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
17+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1618
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1719
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1820
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -34,12 +36,14 @@ BlockDiagonals = "0.1.42"
3436
CUDA = "5.5.2"
3537
ChainRulesCore = "1.25"
3638
Combinatorics = "1.0.2"
39+
CommonSolve = "0.2.4"
3740
ComponentArrays = "0.15.19"
3841
Flux = "v0.15.2, 0.16"
3942
GPUArraysCore = "0.1, 0.2"
4043
LinearAlgebra = "1.10.0"
4144
Lux = "1.4.2"
4245
MLUtils = "0.4.5"
46+
Optimization = "3.19.3, 4"
4347
Random = "1.10.0"
4448
SimpleChains = "0.4"
4549
StatsBase = "0.34.4"

dev/doubleMM.jl

Lines changed: 107 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,123 @@ using StableRNGs
55
using Random
66
using Statistics
77
using ComponentArrays: ComponentArrays as CA
8-
8+
using Optimization
9+
using OptimizationOptimisers # Adam
10+
using UnicodePlots
911
using SimpleChains
10-
import Flux
12+
using Flux
13+
using MLUtils
14+
15+
rng = StableRNG(114)
16+
scenario = NTuple{0, Symbol}()
17+
#scenario = (:use_Flux,)
18+
19+
#------ setup synthetic data and training data loader
20+
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
21+
) = gen_hybridcase_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
22+
get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
23+
σ_o = exp(first(y_unc)/2)
24+
25+
# assign the train_loader, otherwise it eatch time creates another version of synthetic data
26+
prob0 = update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
27+
28+
#------- pointwise hybrid model fit
29+
#solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
30+
solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
31+
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
32+
(; ϕ, resopt) = solve(prob0, solver; scenario,
33+
rng, callback = callback_loss(100), maxiters = 1200);
34+
prob0o = update(prob0; ϕg=ϕ.ϕg, θP=ϕ.θP)
35+
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP);
36+
scatterplot(θMs_true[1,:], θMs[1,:])
37+
scatterplot(θMs_true[2,:], θMs[2,:])
38+
39+
# do a few steps without minibatching,
40+
# by providing the data rather than the DataLoader
41+
# train_loader0 = get_hybridproblem_train_dataloader(rng, prob0; scenario, n_batch=1000)
42+
# get_train_loader_data = (args...; kwargs...) -> train_loader0.data
43+
# prob1 = update(prob0o; get_train_loader = get_train_loader_data)
44+
prob1 = prob0o
45+
46+
#solver1 = HybridPointSolver(; alg = Adam(0.05), n_batch = n_site)
47+
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
48+
(; ϕ, resopt) = solve(prob1, solver1; scenario, rng,
49+
callback = callback_loss(20), maxiters = 600);
50+
prob1o = update(prob1; ϕg=ϕ.ϕg, θP=ϕ.θP)
51+
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP);
52+
scatterplot(θMs_true[1,:], θMs[1,:])
53+
scatterplot(θMs_true[2,:], θMs[2,:])
54+
prob1o.θP
55+
scatterplot(vec(y_true), vec(y_pred))
56+
57+
() -> begin # with more iterations?
58+
prob2 = prob1o
59+
(; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
60+
callback = callback_loss(20), maxiters = 600);
61+
prob2o = update(prob2; ϕg=ϕ.ϕg, θP=ϕ.θP)
62+
y_pred_global, y_pred, θMs = gf(prob2o, xM, xP);
63+
prob2o.θP
64+
end
65+
66+
#----------- fit g to true θMs
67+
# and fit gf starting from true parameters
68+
prob = prob0
69+
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
70+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
71+
72+
function loss_g(ϕg, x, g, transM)
73+
ζMs = g(x, ϕg) # predict the log of the parameters
74+
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
75+
loss = sum(abs2, θMs .- θMs_true)
76+
return loss, θMs
77+
end
78+
loss_g(ϕg0, xM, g, transM)
79+
80+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
81+
Optimization.AutoZygote())
82+
optprob = Optimization.OptimizationProblem(optf, ϕg0);
83+
res = Optimization.solve(optprob, Adam(0.015), callback = callback_loss(100), maxiters = 2000);
84+
85+
ϕg_opt1 = res.u;
86+
l1, θMs = loss_g(ϕg_opt1, xM, g, transM)
87+
#scatterplot(θMs_true[1,:], θMs[1,:])
88+
scatterplot(θMs_true[2,:], θMs[2,:]) # able to fit θMs[2,:]
89+
90+
prob3 = update(prob0, ϕg = ϕg_opt1, θP = θP_true)
91+
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
92+
(; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
93+
callback = callback_loss(50), maxiters = 600);
94+
prob3o = update(prob3; ϕg=ϕ.ϕg, θP=ϕ.θP)
95+
y_pred_global, y_pred, θMs = gf(prob3o, xM, xP);
96+
scatterplot(θMs_true[2,:], θMs[2,:])
97+
prob3o.θP
98+
scatterplot(vec(y_true), vec(y_pred))
99+
scatterplot(vec(y_true), vec(y_o))
100+
scatterplot(vec(y_pred), vec(y_o))
101+
102+
() -> begin # optimized loss is indeed lower than with true parameters
103+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
104+
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
105+
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
106+
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc)[1]
107+
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc)[1]
108+
#
109+
loss_gf(vcat(prob2o.ϕg, prob2o.θP), xM, xP, y_o, y_unc)[1]
110+
end
111+
112+
#----------- Hybrid Variational inference
113+
11114
using MLUtils
12115
import Zygote
13116

14117
using CUDA
15-
using OptimizationOptimisers
16118
using Bijectors
17-
using UnicodePlots
18119

19-
const prob = DoubleMM.DoubleMMCase()
20-
scenario = (:default,)
21-
rng = StableRNG(111)
22-
23-
par_templates = get_hybridproblem_par_templates(prob; scenario)
24120

25121
#n_covar = get_hybridproblem_n_covar(prob; scenario)
26122
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
27123

28-
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
29-
) = gen_hybridcase_synthetic(rng, prob; scenario);
30-
31-
n_covar = size(xM,1)
32-
124+
n_covar = size(xM, 1)
33125

34126
#----- fit g to θMs_true
35127
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
@@ -92,8 +184,6 @@ FT = get_hybridproblem_float_type(prob; scenario)
92184
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
93185
ϕ_true = ϕ
94186

95-
96-
97187
() -> begin
98188
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
99189
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -245,7 +335,7 @@ y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
245335
size(y_pred) # n_obs x n_site, n_sample_pred
246336

247337
σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
248-
σ_o = exp.(y_unc[:,1] / 2)
338+
σ_o = exp.(y_unc[:, 1] / 2)
249339

250340
#describe(σ_o_post)
251341
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),

src/AbstractHybridProblem.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario=())
123123
end
124124

125125
"""
126-
get_hybridproblem_train_dataloader([rng,] ::AbstractHybridProblem; scenario)
126+
get_hybridproblem_train_dataloader([rng,] ::AbstractHybridProblem; scenario, n_batch)
127127
128128
Return a DataLoader that provides a tuple of
129129
- `xM`: matrix of covariates, with one column per site
@@ -132,9 +132,8 @@ Return a DataLoader that provides a tuple of
132132
- `y_unc`: matrix `sizeof(y_o)` of uncertainty information
133133
"""
134134
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::AbstractHybridProblem;
135-
scenario = ())
135+
scenario = (), n_batch = 10)
136136
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, prob; scenario)
137-
n_batch = 10
138137
xM_gpu = :use_Flux scenario ? CuArray(xM) : xM
139138
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
140139
return(train_loader)

src/DoubleMM/f_doubleMM.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
struct DoubleMMCase <: AbstractHybridProblem end
22

33

4-
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
5-
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
4+
const θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
5+
const θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
66

7-
transP = elementwise(exp)
8-
transM = Stacked(elementwise(identity), elementwise(exp))
7+
const transP = elementwise(exp)
8+
const transM = Stacked(elementwise(identity), elementwise(exp))
99

1010

1111
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
@@ -54,13 +54,13 @@ end
5454
# return Float32
5555
# end
5656

57-
const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
57+
const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.1]
5858
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
5959

6060
function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
6161
scenario = ())
6262
n_covar_pc = 2
63-
n_site = 200
63+
n_site = 800
6464
n_covar = 5
6565
n_θM = length(θM)
6666
FloatType = get_hybridproblem_float_type(prob; scenario)

src/HybridProblem.jl

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,77 @@
1-
struct HybridProblem <: AbstractHybridProblem
1+
struct HybridProblem <: AbstractHybridProblem
22
θP
33
θM
44
f
55
g
66
ϕg
7-
py
7+
py
88
transP
99
transM
1010
cor_starts # = (P=(1,),M=(1,))
11-
train_loader
11+
get_train_loader
1212
# inner constructor to constrain the types
1313
function HybridProblem(
14-
θP::CA.ComponentVector, θM::CA.ComponentVector,
15-
g::AbstractModelApplicator, ϕg::AbstractVector,
16-
f::Function,
17-
py::Function,
18-
transM::Union{Function, Bijectors.Transform},
19-
transP::Union{Function, Bijectors.Transform},
20-
train_loader::DataLoader,
21-
cor_starts::NamedTuple = (P=(1,), M=(1,)))
22-
new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, train_loader)
14+
θP::CA.ComponentVector, θM::CA.ComponentVector,
15+
g::AbstractModelApplicator, ϕg::AbstractVector,
16+
f::Function,
17+
py::Function,
18+
transM::Union{Function, Bijectors.Transform},
19+
transP::Union{Function, Bijectors.Transform},
20+
#train_loader::DataLoader,
21+
# return a function that constructs the trainloader based on n_batch
22+
get_train_loader::Function,
23+
cor_starts::NamedTuple = (P = (1,), M = (1,)))
24+
new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, get_train_loader)
2325
end
2426
end
2527

26-
function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector,
27-
# note no ϕg argument and g_chain unconstrained
28-
g_chain, f::Function,
29-
args...; rng = Random.default_rng(), kwargs...)
28+
function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector,
29+
# note no ϕg argument and g_chain unconstrained
30+
g_chain, f::Function,
31+
args...; rng = Random.default_rng(), kwargs...)
3032
# dispatches on type of g_chain
3133
g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM))
3234
HybridProblem(θP, θM, g, ϕg, f, args...; kwargs...)
3335
end
3436

37+
function HybridProblem(prob::AbstractHybridProblem; scenario = ())
38+
(; θP, θM) = get_hybridproblem_par_templates(prob; scenario)
39+
g, ϕg = get_hybridproblem_MLapplicator(prob; scenario)
40+
f = get_hybridproblem_PBmodel(prob; scenario)
41+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
42+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
43+
get_train_loader = let prob = prob, scenario = scenario
44+
function inner_get_train_loader(rng::AbstractRNG; kwargs...)
45+
get_hybridproblem_train_dataloader(rng::AbstractRNG, prob; scenario, kwargs...)
46+
end
47+
end
48+
cor_starts = get_hybridproblem_cor_starts(prob; scenario)
49+
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts)
50+
end
51+
52+
function update(prob::HybridProblem;
53+
θP::CA.ComponentVector = prob.θP,
54+
θM::CA.ComponentVector = prob.θM,
55+
g::AbstractModelApplicator = prob.g, ϕg::AbstractVector = prob.ϕg,
56+
f::Function = prob.f,
57+
py::Function = prob.py,
58+
transM::Union{Function, Bijectors.Transform} = prob.transM,
59+
transP::Union{Function, Bijectors.Transform} = prob.transP,
60+
get_train_loader::Function = prob.get_train_loader,
61+
cor_starts::NamedTuple = prob.cor_starts)
62+
# prob.θP = θP
63+
# prob.θM = θM
64+
# prob.f = f
65+
# prob.g = g
66+
# prob.ϕg = ϕg
67+
# prob.py = py
68+
# prob.transM = transM
69+
# prob.transP = transP
70+
# prob.cor_starts = cor_starts
71+
# prob.get_train_loader = get_train_loader
72+
HybridProblem(θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts)
73+
end
74+
3575
function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ())
3676
(; θP = prob.θP, θM = prob.θM)
3777
end
@@ -54,12 +94,12 @@ function get_hybridproblem_PBmodel(prob::HybridProblem; scenario::NTuple = ())
5494
prob.f
5595
end
5696

57-
function get_hybridproblem_MLapplicator(prob::HybridProblem; scenario::NTuple = ());
97+
function get_hybridproblem_MLapplicator(prob::HybridProblem; scenario::NTuple = ())
5898
prob.g, prob.ϕg
5999
end
60100

61-
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::HybridProblem; scenario = ())
62-
return(prob.train_loader)
101+
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::HybridProblem; kwargs...)
102+
return prob.get_train_loader(rng; kwargs...)
63103
end
64104

65105
function get_hybridproblem_cor_starts(prob::HybridProblem; scenario = ())
@@ -69,6 +109,3 @@ end
69109
# function get_hybridproblem_float_type(prob::HybridProblem; scenario::NTuple = ())
70110
# eltype(prob.θM)
71111
# end
72-
73-
74-

0 commit comments

Comments
 (0)