Skip to content

Commit 6da5b81

Browse files
committed
implement HybridPointSolver on gpu
1 parent 84743f6 commit 6da5b81

12 files changed

+115
-70
lines changed

dev/doubleMM.jl

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,49 +11,51 @@ using UnicodePlots
1111
using SimpleChains
1212
using Flux
1313
using MLUtils
14+
using CUDA
1415

1516
rng = StableRNG(114)
1617
scenario = NTuple{0, Symbol}()
17-
#scenario = (:use_Flux,)
18+
scenario = (:use_Flux,)
1819

1920
#------ setup synthetic data and training data loader
2021
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
2122
) = gen_hybridcase_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
23+
xM_cpu = xM
24+
if :use_Flux ∈ scenario
25+
xM = CuArray(xM_cpu)
26+
end
2227
get_train_loader = (rng; n_batch, kwargs...) -> MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
2328
σ_o = exp(first(y_unc)/2)
2429

2530
# assign the train_loader, otherwise it each time creates another version of synthetic data
26-
prob0 = update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
31+
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
2732

2833
#------- pointwise hybrid model fit
2934
#solver = HybridPointSolver(; alg = Adam(0.02), n_batch = 30)
3035
solver = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
3136
#solver = HybridPointSolver(; alg = Adam(), n_batch = 200)
3237
(; ϕ, resopt) = solve(prob0, solver; scenario,
3338
rng, callback = callback_loss(100), maxiters = 1200);
34-
prob0o = update(prob0; ϕg=ϕ.ϕg, θP=ϕ.θP)
35-
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP);
39+
# update the problem with optimized parameters
40+
prob0o = HVI.update(prob0; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP)
41+
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
3642
scatterplot(θMs_true[1,:], θMs[1,:])
3743
scatterplot(θMs_true[2,:], θMs[2,:])
3844

3945
# do a few steps without minibatching,
4046
# by providing the data rather than the DataLoader
41-
# train_loader0 = get_hybridproblem_train_dataloader(rng, prob0; scenario, n_batch=1000)
42-
# get_train_loader_data = (args...; kwargs...) -> train_loader0.data
43-
# prob1 = update(prob0o; get_train_loader = get_train_loader_data)
44-
prob1 = prob0o
45-
46-
#solver1 = HybridPointSolver(; alg = Adam(0.05), n_batch = n_site)
4747
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
48-
(; ϕ, resopt) = solve(prob1, solver1; scenario, rng,
48+
(; ϕ, resopt) = solve(prob0o, solver1; scenario, rng,
4949
callback = callback_loss(20), maxiters = 600);
50-
prob1o = update(prob1; ϕg=ϕ.ϕg, θP=ϕ.θP)
51-
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP);
50+
prob1o = HVI.update(prob0o; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP);
51+
y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario);
5252
scatterplot(θMs_true[1,:], θMs[1,:])
5353
scatterplot(θMs_true[2,:], θMs[2,:])
5454
prob1o.θP
5555
scatterplot(vec(y_true), vec(y_pred))
5656

57+
# still overestimating θMs
58+
5759
() -> begin # with more iterations?
5860
prob2 = prob1o
5961
(; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
@@ -63,50 +65,55 @@ scatterplot(vec(y_true), vec(y_pred))
6365
prob2o.θP
6466
end
6567

66-
#----------- fit g to true θMs
67-
# and fit gf starting from true parameters
68-
prob = prob0
69-
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
70-
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
71-
72-
function loss_g(ϕg, x, g, transM)
73-
ζMs = g(x, ϕg) # predict the log of the parameters
74-
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
75-
loss = sum(abs2, θMs .- θMs_true)
76-
return loss, θMs
77-
end
78-
loss_g(ϕg0, xM, g, transM)
7968

80-
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
81-
Optimization.AutoZygote())
82-
optprob = Optimization.OptimizationProblem(optf, ϕg0);
83-
res = Optimization.solve(optprob, Adam(0.015), callback = callback_loss(100), maxiters = 2000);
84-
85-
ϕg_opt1 = res.u;
86-
l1, θMs = loss_g(ϕg_opt1, xM, g, transM)
87-
#scatterplot(θMs_true[1,:], θMs[1,:])
88-
scatterplot(θMs_true[2,:], θMs[2,:]) # able to fit θMs[2,:]
89-
90-
prob3 = update(prob0, ϕg = ϕg_opt1, θP = θP_true)
91-
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
92-
(; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
93-
callback = callback_loss(50), maxiters = 600);
94-
prob3o = update(prob3; ϕg=ϕ.ϕg, θP=ϕ.θP)
95-
y_pred_global, y_pred, θMs = gf(prob3o, xM, xP);
96-
scatterplot(θMs_true[2,:], θMs[2,:])
97-
prob3o.θP
98-
scatterplot(vec(y_true), vec(y_pred))
99-
scatterplot(vec(y_true), vec(y_o))
100-
scatterplot(vec(y_pred), vec(y_o))
69+
#----------- fit g to true θMs
70+
() -> begin
71+
# and fit gf starting from true parameters
72+
prob = prob0
73+
g, ϕg0_cpu = get_hybridproblem_MLapplicator(prob; scenario);
74+
ϕg0 = (:use_Flux ∈ scenario) ? CuArray(ϕg0_cpu) : ϕg0_cpu
75+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
76+
77+
function loss_g(ϕg, x, g, transM; gpu_handler = HVI.default_GPU_DataHandler)
78+
ζMs = g(x, ϕg) # predict the log of the parameters
79+
ζMs_cpu = gpu_handler(ζMs)
80+
θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each column
81+
loss = sum(abs2, θMs .- θMs_true)
82+
return loss, θMs
83+
end
84+
loss_g(ϕg0, xM, g, transM)
85+
86+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
87+
Optimization.AutoZygote())
88+
optprob = Optimization.OptimizationProblem(optf, ϕg0);
89+
res = Optimization.solve(optprob, Adam(0.015), callback = callback_loss(100), maxiters = 2000);
90+
91+
ϕg_opt1 = res.u;
92+
l1, θMs = loss_g(ϕg_opt1, xM, g, transM)
93+
#scatterplot(θMs_true[1,:], θMs[1,:])
94+
scatterplot(θMs_true[2,:], θMs[2,:]) # able to fit θMs[2,:]
95+
96+
prob3 = HVI.update(prob0, ϕg = Array(ϕg_opt1), θP = θP_true)
97+
solver1 = HybridPointSolver(; alg = Adam(0.01), n_batch = n_site)
98+
(; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
99+
callback = callback_loss(50), maxiters = 600);
100+
prob3o = HVI.update(prob3; ϕg=cpu_ca(ϕ).ϕg, θP=cpu_ca(ϕ).θP)
101+
y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario);
102+
scatterplot(θMs_true[2,:], θMs[2,:])
103+
prob3o.θP
104+
scatterplot(vec(y_true), vec(y_pred))
105+
scatterplot(vec(y_true), vec(y_o))
106+
scatterplot(vec(y_pred), vec(y_o))
101107

102-
() -> begin # optimized loss is indeed lower than with true parameters
103-
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
104-
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
105-
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
106-
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc)[1]
107-
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc)[1]
108-
#
109-
loss_gf(vcat(prob2o.ϕg, prob2o.θP), xM, xP, y_o, y_unc)[1]
108+
() -> begin # optimized loss is indeed lower than with true parameters
109+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
110+
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
111+
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
112+
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc)[1]
113+
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc)[1]
114+
#
115+
loss_gf(vcat(prob2o.ϕg, prob2o.θP), xM, xP, y_o, y_unc)[1]
116+
end
110117
end
111118

112119
#----------- Hybrid Variational inference

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ function HVI.construct_3layer_MLApplicator(
5555
construct_ChainsApplicator(rng, g_chain, float_type)
5656
end
5757

58+
function HVI.cpu_ca(ca::CA.ComponentArray)
59+
CA.ComponentArray(cpu(CA.getdata(ca)), CA.getaxes(ca))
60+
end
61+
62+
5863

5964

6065
end # module

src/DoubleMM/f_doubleMM.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple =
2222
(; θP, θM)
2323
end
2424

25+
function HVI.get_hybridproblem_MLapplicator(
26+
rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ())
27+
ml_engine = select_ml_engine(; scenario)
28+
construct_3layer_MLApplicator(rng, prob, ml_engine; scenario)
29+
end
30+
2531
function HVI.get_hybridproblem_transforms(::DoubleMMCase; scenario::NTuple = ())
2632
(; transP, transM)
2733
end
@@ -91,11 +97,6 @@ function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
9197
)
9298
end
9399

94-
function HVI.get_hybridproblem_MLapplicator(
95-
rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ())
96-
ml_engine = select_ml_engine(; scenario)
97-
construct_3layer_MLApplicator(rng, prob, ml_engine; scenario)
98-
end
99100

100101

101102

src/HybridSolver.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
2424
f = get_hybridproblem_PBmodel(prob; scenario)
2525
y_global_o = FT[] # TODO
2626
loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP)
27-
#l1 = loss_gf(p0, train_loader...)[1]
27+
# data1 = first(train_loader)
28+
# l1 = loss_gf(p0, first(train_loader)...)[1]
2829
# Zygote.gradient(p0 -> loss_gf(p0, data1...)[1], p0)
2930
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
3031
Optimization.AutoZygote())

src/HybridVariationalInference.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_
3131
get_hybridproblem_par_templates, get_hybridproblem_transforms, get_hybridproblem_train_dataloader,
3232
get_hybridproblem_neg_logden_obs,
3333
get_hybridproblem_n_covar,
34-
update,
34+
#update,
3535
gen_cov_pred
3636
include("AbstractHybridProblem.jl")
3737

@@ -47,6 +47,9 @@ include("gencovar.jl")
4747
export callback_loss
4848
include("util_opt.jl")
4949

50+
export cpu_ca
51+
include("util_ca.jl")
52+
5053
export neg_logden_indep_normal, entropy_MvNormal
5154
include("logden_normal.jl")
5255

src/gf.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@ function gf(g, transM, f, xM, xP, ϕg, θP; gpu_handler = default_GPU_DataHandle
2424
# @show first(ϕg,5)
2525
ζMs = g(xM, ϕg) # predict the log of the parameters
2626
ζMs_cpu = gpu_handler(ζMs)
27+
if θP isa SubArray && !(gpu_handler isa NullGPUDataHandler)
28+
# otherwise Zygote fails on gpu_handler
29+
θP = copy(θP)
30+
end
31+
θP_cpu = gpu_handler(CA.getdata(θP))
2732
θMs = reduce(hcat, map(transM, eachcol(ζMs_cpu))) # transform each column
28-
y_pred_global, y_pred = f(θP, θMs, xP)
33+
y_pred_global, y_pred = f(θP_cpu, θMs, xP)
2934
return y_pred_global, y_pred, θMs
3035
end
3136

@@ -34,7 +39,8 @@ function gf(prob::AbstractHybridProblem, xM, xP, args...; scenario = (), kwargs.
3439
f = get_hybridproblem_PBmodel(prob; scenario)
3540
(; θP, θM) = get_hybridproblem_par_templates(prob; scenario)
3641
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
37-
gf(g, transM, f, xM, xP, ϕg, θP; kwargs...)
42+
ϕg_dev, θP_dev = (:use_Flux ∈ scenario) ? (CuArray(ϕg), CuArray(CA.getdata(θP))) : (ϕg, CA.getdata(θP))
43+
gf(g, transM, f, xM, xP, ϕg_dev, θP_dev; kwargs...)
3844
end
3945

4046
"""
@@ -50,7 +56,8 @@ function get_loss_gf(g, transM, f, y_o_global, int_ϕθP::AbstractComponentArray
5056
function loss_gf(p, xM, xP, y_o, y_unc)
5157
σ = exp.(y_unc ./ 2)
5258
pc = int_ϕθP(p)
53-
y_pred_global, y_pred, θMs = gf(g, transM, f, xM, xP, pc.ϕg, pc.θP)
59+
y_pred_global, y_pred, θMs = gf(
60+
g, transM, f, xM, xP, CA.getdata(pc.ϕg), CA.getdata(pc.θP))
5461
loss = sum(abs2, (y_pred .- y_o) ./ σ) + sum(abs2, y_pred_global .- y_o_global)
5562
return loss, y_pred_global, y_pred, θMs
5663
end

src/util_ca.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
cpu_ca(ca::CA.ComponentArray)
3+
4+
Move a ComponentArray from GPU to CPU memory.
5+
"""
6+
function cpu_ca end
7+
# define in FluxExt
8+
9+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml
33

44
@time begin
55
if GROUP == "All" || GROUP == "Basic"
6+
@time @safetestset "test_HybridProblem" include("test_HybridProblem.jl")
67
#@safetestset "test" include("test/test_ComponentArrayInterpreter.jl")
78
@time @safetestset "test_ComponentArrayInterpreter" include("test_ComponentArrayInterpreter.jl")
89
#@safetestset "test" include("test/test_gencovar.jl")

test/test_ComponentArrayInterpreter.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ using ComponentArrays: ComponentArrays as CA
77
component_counts = comp_cnts = (; P=2, M=3, Unc=5)
88
m = ComponentArrayInterpreter(; comp_cnts...)
99
testm = (m) -> begin
10-
@test CM._get_ComponentArrayInterpreter_axes(m) == (CA.Axis(P=1:2, M=3:5, Unc=6:10),)
10+
#type of axes may differ
11+
#@test CM._get_ComponentArrayInterpreter_axes(m) == (CA.Axis(P=1:2, M=3:5, Unc=6:10),)
1112
@test length(m) == 10
1213
v = 1:length(m)
1314
cv = m(v)
1415
@test cv.Unc == 6:10
1516
end
1617
testm(m)
18+
m = get_concrete(m)
1719
testm(get_concrete(m))
1820
Base.isconcretetype(typeof(m))
1921
end;

test/test_Flux.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
using StatsFuns: logistic
33
using CUDA, GPUArraysCore
4+
using ComponentArrays: ComponentArrays as CA
45

56
using HybridVariationalInference
67
# @testset "get_default_GPUHandler before loading Flux" begin
@@ -53,3 +54,10 @@ end;
5354
@test size(y) == (n_out, n_site)
5455
end;
5556

57+
@testset "cpu_ca" begin
58+
c1 = CA.ComponentVector(a=(a1=1,a2=2:3),b=3:4)
59+
c1_gpu = gpu(c1)
60+
#cpu(c1_gpu) # fails
61+
@test cpu_ca(c1_gpu) == c1
62+
end;
63+

0 commit comments

Comments
 (0)