Skip to content

Commit c931205

Browse files
authored
Vectorized Example process model executed on GPU (#21)
* matrix version and neglogden on GPU * Implement Bijector Exp that also works on AD on GPU * implement broadcasted doubleMM f_PBM * move n_batch from solver to problem * remove problem responsibility of putting dataloader to gpu but do it in fitting and predicting functions * fix error in solver of properly updating problem.phi_unc * replace transformations by Exp because elementwise(exp) failed on AD on GPU * remove n_site from gen_hybridproblem_synthetic
1 parent 945faad commit c931205

21 files changed

+584
-222
lines changed

dev/doubleMM.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,34 @@ gdev = :use_gpu ∈ scenario ? gpu_device() : identity
2828
cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity
2929

3030
#------ setup synthetic data and training data loader
31-
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
31+
prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario);
32+
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
3233
) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
33-
#n_site = get_hybridproblem_n_site(DoubleMM.DoubleMMCase(); scenario)
34+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
3435
ζP_true, ζMs_true = log.(θP_true), log.(θMs_true)
3536
i_sites = 1:n_site
36-
xM_cpu = xM;
37-
xM = xM_cpu |> gdev;
38-
get_train_loader = (; n_batch, kwargs...) -> MLUtils.DataLoader(
37+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
38+
train_dataloader = MLUtils.DataLoader(
3939
(xM, xP, y_o, y_unc, 1:n_site);
4040
batchsize = n_batch, partial = false)
4141
σ_o = exp.(y_unc[:, 1] / 2)
42-
4342
# assign the train_loader, otherwise each time it creates another version of synthetic data
44-
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
43+
prob0 = HVI.update(prob0_; train_dataloader);
4544
#tmp = HVI.get_hybridproblem_ϕunc(prob0; scenario)
4645

4746
#------- pointwise hybrid model fit
48-
solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01), n_batch = 30)
47+
solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01))
4948
#solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 30)
5049
#solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
5150
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
52-
n_batches_in_epoch = n_site ÷ solver_point.n_batch
51+
n_batches_in_epoch = n_site ÷ n_batch
5352
n_epoch = 80
5453
(; ϕ, resopt, probo) = solve(prob0, solver_point; scenario,
5554
rng, callback = callback_loss(n_batches_in_epoch * 10),
5655
maxiters = n_batches_in_epoch * n_epoch);
5756
# update the problem with optimized parameters
5857
prob0o = probo;
59-
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
58+
y_pred_global, y_pred, θMs = gf(prob0o, scenario);
6059
plt = scatterplot(θMs_true[1, :], θMs[1, :]);
6160
lineplot!(plt, 0, 1)
6261
scatterplot(θMs_true[2, :], θMs[2, :])
@@ -149,10 +148,10 @@ probh = prob0o # start from point optimized to infer uncertainty
149148
#probh = prob1o # start from point optimized to infer uncertainty
150149
#probh = prob0 # start from no information
151150
solver_post = HybridPosteriorSolver(;
152-
alg = OptimizationOptimisers.Adam(0.01), n_batch = min(50, n_site), n_MC = 3)
151+
alg = OptimizationOptimisers.Adam(0.01), n_MC = 3)
153152
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
154-
n_batches_in_epoch = n_site ÷ solver_post.n_batch
155-
n_epoch = 80
153+
n_batches_in_epoch = n_site ÷ n_batch
154+
n_epoch = 40
156155
(; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario,
157156
rng, callback = callback_loss(n_batches_in_epoch * 5),
158157
maxiters = n_batches_in_epoch * n_epoch,
@@ -213,6 +212,7 @@ end
213212
n_sample_pred = 400
214213
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
215214
(θ2_indep, y2_indep) = (θ, y)
215+
#(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed
216216
end
217217

218218
() -> begin # optimize using LUX
@@ -246,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :])
246246

247247
# test predicting correct obs-uncertainty of predictive posterior
248248
n_sample_pred = 400
249-
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o, xM, xP; scenario, n_sample_pred);
249+
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred);
250250
(θ2, y2) = (θ, y)
251251
size(y) # n_obs x n_site, n_sample_pred
252252
size(θ) # n_θP + n_site * n_θM x n_sample
@@ -506,12 +506,13 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread
506506
using JLD2
507507
fname = "intermediate/doubleMM_chain_zeta_$(last(scenario)).jld2"
508508
jldsave(fname, false, IOStream; chain)
509-
chain = load(fname, "chain"; iotype = IOStream)
509+
chain = load(fname, "chain"; iotype = IOStream);
510510
end
511511

512512
#ζi = first(eachrow(Array(chain)))
513+
f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true)
513514
ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain)));
514-
(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
515+
(; θ, y) = HVI.predict_ζf(ζs, f_allsites, xP, trans_PMs_gen, intm_PMs_gen);
515516
(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
516517

517518

dev/negLogDensity.pdf

-10.1 KB
Binary file not shown.

dev/r1_density.pdf

-9.96 KB
Binary file not shown.

dev/ys_density.pdf

-11.6 KB
Binary file not shown.

src/AbstractHybridProblem.jl

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ For a specific prob, provide functions that specify details
1212
- `get_hybridproblem_train_dataloader` (may use `construct_dataloader_from_synthetic`)
1313
- `get_hybridproblem_priors`
1414
- `get_hybridproblem_n_covar`
15-
- `get_hybridproblem_n_site`
15+
- `get_hybridproblem_n_site_and_batch`
1616
optionally
1717
- `gen_hybridproblem_synthetic`
1818
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
@@ -125,11 +125,11 @@ function get_hybridproblem_pbmpar_covars(::AbstractHybridProblem; scenario)
125125
end
126126

127127
"""
128-
get_hybridproblem_n_site(::AbstractHybridProblem; scenario)
128+
get_hybridproblem_n_site_and_batch(::AbstractHybridProblem; scenario)
129129
130130
Provide the number of sites and the batch size as `(n_site, n_batch)`.
131131
"""
132-
function get_hybridproblem_n_site end
132+
function get_hybridproblem_n_site_and_batch end
133133

134134

135135
"""
@@ -172,30 +172,51 @@ function get_hybridproblem_train_dataloader end
172172
scenario = (), n_batch)
173173
174174
Construct a dataloader based on `gen_hybridproblem_synthetic`.
175-
gdev is applied to xM.
176-
If :f_on_gpu is in scenario tuple, gdev is also applied to `xP`, `y_o`, and `y_unc`,
177-
to put the entire data to gpu.
178-
Alternatively, gdev could be applied to the dataloader, then for each
179-
iteration the subset of data is separately transferred to gpu.
180175
"""
181176
function construct_dataloader_from_synthetic(rng::AbstractRNG, prob::AbstractHybridProblem;
182177
scenario = (), n_batch,
183-
gdev = :use_gpu ∈ scenario ? gpu_device() : identity,
178+
#gdev = :use_gpu ∈ scenario ? gpu_device() : identity,
184179
)
185180
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario)
186181
n_site = size(xM,2)
187182
@assert length(xP) == n_site
188183
@assert size(y_o,2) == n_site
189184
@assert size(y_unc,2) == n_site
190185
i_sites = 1:n_site
191-
xM_dev = gdev(xM)
192-
xP_dev, y_o_dev, y_unc_dev = :f_on_gpu ∈ scenario ?
193-
(gdev(xP), gdev(y_o), gdev(y_unc)) : (xP, y_o, y_unc)
194-
train_loader = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites);
186+
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites);
195187
batchsize = n_batch, partial = false)
196188
return (train_loader)
197189
end
198190

191+
192+
"""
193+
gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
194+
scenario = (),
195+
gdev = gpu_device(),
196+
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
197+
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
198+
batchsize = dataloader.batchsize,
199+
partial = dataloader.partial
200+
)
201+
202+
Put relevant parts of the DataLoader to gpu, depending on scenario.
203+
"""
204+
function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
205+
scenario = (),
206+
gdev = gpu_device(),
207+
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
208+
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
209+
batchsize = dataloader.batchsize,
210+
partial = dataloader.partial
211+
)
212+
xM, xP, y_o, y_unc, i_sites = dataloader.data
213+
xM_dev = gdev_M(xM)
214+
xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc))
215+
train_loader_dev = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites);
216+
batchsize, partial)
217+
return(train_loader_dev)
218+
end
219+
199220
# function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ())
200221
# rng::AbstractRNG = Random.default_rng()
201222
# get_hybridproblem_train_dataloader(rng, prob; scenario)

src/ComponentArrayInterpreter.jl

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
66
Interface for Type that implements
77
- `as_ca(::AbstractArray, interpreter) -> ComponentArray`
8+
- `ComponentArrays.getaxes(interpreter)`
89
- `Base.length(interpreter) -> Int`
910
1011
When called on a vector, forwards to `as_ca`.
12+
13+
There is a default implementation for Base.length based on ComponentArrays.getaxes.
1114
"""
1215
abstract type AbstractComponentArrayInterpreter end
1316

@@ -18,6 +21,11 @@ Returns a ComponentArray with underlying data `v`.
1821
"""
1922
function as_ca end
2023

24+
function Base.length(cai::AbstractComponentArrayInterpreter)
25+
prod(_axis_length.(CA.getaxes(cai)))
26+
end
27+
28+
2129
(interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray) = as_ca(v, interpreter)
2230

2331
"""
@@ -36,9 +44,13 @@ function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {A
3644
CA.ComponentArray(vr, AX)
3745
end
3846

39-
function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
40-
#sum(length, typeof(AX).parameters[1])
41-
prod(_axis_length.(AX))
47+
# function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
48+
# #sum(length, typeof(AX).parameters[1])
49+
# prod(_axis_length.(AX))
50+
# end
51+
52+
function CA.getaxes(int::StaticComponentArrayInterpreter{AX}) where {AX}
53+
AX
4254
end
4355

4456
get_concrete(cai::StaticComponentArrayInterpreter) = cai
@@ -63,10 +75,11 @@ function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter)
6375
CA.ComponentArray(vr, cai.axes)
6476
end
6577

66-
function Base.length(cai::ComponentArrayInterpreter)
67-
prod(_axis_length.(cai.axes))
78+
function CA.getaxes(cai::ComponentArrayInterpreter)
79+
cai.axes
6880
end
6981

82+
7083
get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{cai.axes}()
7184

7285

@@ -120,6 +133,10 @@ function ComponentArrayInterpreter(
120133
ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where N
121134
ComponentArrayInterpreter(CA.getaxes(ca), n_dims)
122135
end
136+
function ComponentArrayInterpreter(
137+
cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where N
138+
ComponentArrayInterpreter(CA.getaxes(cai), n_dims)
139+
end
123140
function ComponentArrayInterpreter(
124141
axes::NTuple{M, <:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
125142
axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...)
@@ -131,12 +148,17 @@ function ComponentArrayInterpreter(
131148
n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where N
132149
ComponentArrayInterpreter(n_dims, CA.getaxes(ca))
133150
end
151+
function ComponentArrayInterpreter(
152+
n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where N
153+
ComponentArrayInterpreter(n_dims, CA.getaxes(cai))
154+
end
134155
function ComponentArrayInterpreter(
135156
n_dims::NTuple{N,<:Integer}, axes::NTuple{M, <:CA.AbstractAxis}) where {N,M}
136157
axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...)
137158
ComponentArrayInterpreter(axes_ext)
138159
end
139160

161+
140162
# ambiguity with two empty Tuples (edge case that does not make sense)
141163
# Empty ComponentVector with no other array dimensions -> empty componentVector
142164
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
@@ -156,6 +178,8 @@ _axis_length(::CA.FlatAxis) = 0
156178
_axis_length(::CA.UnitRange) = 0
157179

158180
"""
181+
flatten1(cv::CA.ComponentVector)
182+
159183
Removes the highest level of keys.
160184
Keeps the reference to the underlying data, but changes the axis.
161185
If first-level vector has no sub-names, an error (ArgumentError: tuple must be non-empty)
@@ -174,3 +198,16 @@ function flatten1(cv::CA.ComponentVector)
174198
CA.ComponentVector(cv, first(CA.getaxes(cv_new)))
175199
end
176200
end
201+
202+
203+
"""
204+
get_positions(cai::AbstractComponentArrayInterpreter)
205+
206+
Create a NamedTuple of integer indices for each component.
207+
Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector.
208+
"""
209+
function get_positions(cai::AbstractComponentArrayInterpreter)
210+
@assert length(CA.getaxes(cai)) == 1
211+
cv = cai(1:length(cai))
212+
(; (k => cv[k] for k in keys(cv))... )
213+
end

0 commit comments

Comments
 (0)