Skip to content

Commit 56adfbf

Browse files
committed
implement broadcasted doubleMM f_PBM
move n_batch from solver to problem; remove the problem's responsibility of putting the dataloader to gpu, but do it in the fitting and predicting functions; fix error in solver of properly updating problem.phi_unc
1 parent 05dfc1f commit 56adfbf

16 files changed

+331
-214
lines changed

dev/doubleMM.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,34 @@ gdev = :use_gpu ∈ scenario ? gpu_device() : identity
2828
cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity
2929

3030
#------ setup synthetic data and training data loader
31+
prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario);
3132
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
32-
) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
33-
#n_site = get_hybridproblem_n_site(DoubleMM.DoubleMMCase(); scenario)
33+
) = gen_hybridproblem_synthetic(rng, prob0_; scenario);
34+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
3435
ζP_true, ζMs_true = log.(θP_true), log.(θMs_true)
3536
i_sites = 1:n_site
36-
xM_cpu = xM;
37-
xM = xM_cpu |> gdev;
38-
get_train_loader = (; n_batch, kwargs...) -> MLUtils.DataLoader(
37+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
38+
train_dataloader = MLUtils.DataLoader(
3939
(xM, xP, y_o, y_unc, 1:n_site);
4040
batchsize = n_batch, partial = false)
4141
σ_o = exp.(y_unc[:, 1] / 2)
42-
4342
# assign the train_loader, otherwise it creates another version of the synthetic data each time
44-
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
43+
prob0 = HVI.update(prob0_; train_dataloader);
4544
#tmp = HVI.get_hybridproblem_ϕunc(prob0; scenario)
4645

4746
#------- pointwise hybrid model fit
48-
solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01), n_batch = 30)
47+
solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01))
4948
#solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 30)
5049
#solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
5150
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
52-
n_batches_in_epoch = n_site ÷ solver_point.n_batch
51+
n_batches_in_epoch = n_site ÷ n_batch
5352
n_epoch = 80
5453
(; ϕ, resopt, probo) = solve(prob0, solver_point; scenario,
5554
rng, callback = callback_loss(n_batches_in_epoch * 10),
5655
maxiters = n_batches_in_epoch * n_epoch);
5756
# update the problem with optimized parameters
5857
prob0o = probo;
59-
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
58+
y_pred_global, y_pred, θMs = gf(prob0o, scenario);
6059
plt = scatterplot(θMs_true[1, :], θMs[1, :]);
6160
lineplot!(plt, 0, 1)
6261
scatterplot(θMs_true[2, :], θMs[2, :])
@@ -149,10 +148,10 @@ probh = prob0o # start from point optimized to infer uncertainty
149148
#probh = prob1o # start from point optimized to infer uncertainty
150149
#probh = prob0 # start from no information
151150
solver_post = HybridPosteriorSolver(;
152-
alg = OptimizationOptimisers.Adam(0.01), n_batch = min(50, n_site), n_MC = 3)
151+
alg = OptimizationOptimisers.Adam(0.01), n_MC = 3)
153152
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
154-
n_batches_in_epoch = n_site ÷ solver_post.n_batch
155-
n_epoch = 80
153+
n_batches_in_epoch = n_site ÷ n_batch
154+
n_epoch = 40
156155
(; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario,
157156
rng, callback = callback_loss(n_batches_in_epoch * 5),
158157
maxiters = n_batches_in_epoch * n_epoch,
@@ -213,6 +212,7 @@ end
213212
n_sample_pred = 400
214213
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
215214
(θ2_indep, y2_indep) = (θ, y)
215+
#(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed
216216
end
217217

218218
() -> begin # optimize using LUX
@@ -246,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :])
246246

247247
# test predicting correct obs-uncertainty of predictive posterior
248248
n_sample_pred = 400
249-
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o, xM, xP; scenario, n_sample_pred);
249+
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred);
250250
(θ2, y2) = (θ, y)
251251
size(y) # n_obs x n_site, n_sample_pred
252252
size(θ) # n_θP + n_site * n_θM x n_sample
@@ -506,12 +506,13 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread
506506
using JLD2
507507
fname = "intermediate/doubleMM_chain_zeta_$(last(scenario)).jld2"
508508
jldsave(fname, false, IOStream; chain)
509-
chain = load(fname, "chain"; iotype = IOStream)
509+
chain = load(fname, "chain"; iotype = IOStream);
510510
end
511511

512512
#ζi = first(eachrow(Array(chain)))
513+
f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true)
513514
ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain)));
514-
(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
515+
(; θ, y) = HVI.predict_ζf(ζs, f_allsites, xP, trans_PMs_gen, intm_PMs_gen);
515516
(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);
516517

517518

dev/negLogDensity.pdf

-10.1 KB
Binary file not shown.

dev/r1_density.pdf

-9.96 KB
Binary file not shown.

dev/ys_density.pdf

-11.6 KB
Binary file not shown.

src/AbstractHybridProblem.jl

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ For a specific prob, provide functions that specify details
1212
- `get_hybridproblem_train_dataloader` (may use `construct_dataloader_from_synthetic`)
1313
- `get_hybridproblem_priors`
1414
- `get_hybridproblem_n_covar`
15-
- `get_hybridproblem_n_site`
15+
- `get_hybridproblem_n_site_and_batch`
1616
optionally
1717
- `gen_hybridproblem_synthetic`
1818
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
@@ -125,11 +125,11 @@ function get_hybridproblem_pbmpar_covars(::AbstractHybridProblem; scenario)
125125
end
126126

127127
"""
128-
get_hybridproblem_n_site(::AbstractHybridProblem; scenario)
128+
get_hybridproblem_n_site_and_batch(::AbstractHybridProblem; scenario)
129129
130130
Provide the number of sites and the batch size, as the tuple `(n_site, n_batch)`.
131131
"""
132-
function get_hybridproblem_n_site end
132+
function get_hybridproblem_n_site_and_batch end
133133

134134

135135
"""
@@ -172,30 +172,51 @@ function get_hybridproblem_train_dataloader end
172172
scenario = (), n_batch)
173173
174174
Construct a dataloader based on `gen_hybridproblem_synthetic`.
175-
gdev is applied to xM.
176-
If :f_on_gpu is in scenario tuple, gdev is also applied to `xP`, `y_o`, and `y_unc`,
177-
to put the entire data to gpu.
178-
Alternatively, gdev could be applied to the dataloader, then for each
179-
iteration the subset of data is separately transferred to gpu.
180175
"""
181176
function construct_dataloader_from_synthetic(rng::AbstractRNG, prob::AbstractHybridProblem;
182177
scenario = (), n_batch,
183-
gdev = :use_gpu ∈ scenario ? gpu_device() : identity,
178+
#gdev = :use_gpu ∈ scenario ? gpu_device() : identity,
184179
)
185180
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario)
186181
n_site = size(xM,2)
187182
@assert length(xP) == n_site
188183
@assert size(y_o,2) == n_site
189184
@assert size(y_unc,2) == n_site
190185
i_sites = 1:n_site
191-
xM_dev = gdev(xM)
192-
xP_dev, y_o_dev, y_unc_dev = :f_on_gpu ∈ scenario ?
193-
(gdev(xP), gdev(y_o), gdev(y_unc)) : (xP, y_o, y_unc)
194-
train_loader = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites);
186+
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites);
195187
batchsize = n_batch, partial = false)
196188
return (train_loader)
197189
end
198190

191+
192+
"""
193+
gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader,
194+
scenario = (),
195+
gdev = gpu_device(),
196+
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
197+
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
198+
batchsize = dataloader.batchsize,
199+
partial = dataloader.partial
200+
)
201+
202+
Put relevant parts of the DataLoader to gpu, depending on scenario.
203+
"""
204+
function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
205+
scenario = (),
206+
gdev = gpu_device(),
207+
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
208+
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
209+
batchsize = dataloader.batchsize,
210+
partial = dataloader.partial
211+
)
212+
xM, xP, y_o, y_unc, i_sites = dataloader.data
213+
xM_dev = gdev_M(xM)
214+
xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc))
215+
train_loader_dev = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites);
216+
batchsize, partial)
217+
return(train_loader_dev)
218+
end
219+
199220
# function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ())
200221
# rng::AbstractRNG = Random.default_rng()
201222
# get_hybridproblem_train_dataloader(rng, prob; scenario)

src/DoubleMM/f_doubleMM.jl

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,49 @@ end
136136
# (; n_covar, n_batch, n_θM, n_θP)
137137
# end
138138

139+
# function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = (),
140+
# gdev = :f_on_gpu ∈ scenario ? gpu_device() : identity,
141+
# )
142+
# #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
143+
# par_templates = get_hybridproblem_par_templates(prob; scenario)
144+
# intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
145+
# let θFix = gdev(θFix), intθ = get_concrete(intθ)
146+
# function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
147+
# pred_sites = applyf(f_doubleMM, θMs, θP, θFix, xP, intθ)
148+
# pred_global = eltype(pred_sites)[]
149+
# return pred_global, pred_sites
150+
# end
151+
# end
152+
# end
153+
139154
function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = (),
155+
use_all_sites = false,
140156
gdev = :f_on_gpu ∈ scenario ? gpu_device() : identity,
141157
)
158+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
159+
n_site_batch = use_all_sites ? n_site : n_batch
142160
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
143161
par_templates = get_hybridproblem_par_templates(prob; scenario)
144-
intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
145-
let θFix = gdev(θFix), intθ = get_concrete(intθ)
146-
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
147-
pred_sites = applyf(f_doubleMM, θMs, θP, θFix, x, intθ)
162+
intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
163+
θFix = repeat(θFix1', n_site_batch)
164+
intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1))
165+
isP = repeat(axes(par_templates.θP,1)', n_site_batch)
166+
let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP=isP, n_site_batch=n_site_batch
167+
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
168+
@assert length(xP) == n_site_batch
169+
@assert size(θMs,2) == n_site_batch
170+
# convert vector of tuples to tuple of matricesByRows
171+
# need to supply xP as vectorOfTuples to work with DataLoader
172+
# k = first(keys(xP[1]))
173+
xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
174+
#stack(map(r -> r[k], xP))'
175+
stack(map(r -> r[k], xP); dims = 1)
176+
end)...)
177+
#xPM = map(transpose, xPM1)
178+
# make sure the same order of columns as in intθ
179+
θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? θFix_dev : θFix
180+
θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs)', θFixd)
181+
pred_sites = f_doubleMM(θ, xPM, intθ)'
148182
pred_global = eltype(pred_sites)[]
149183
return pred_global, pred_sites
150184
end
@@ -157,6 +191,7 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = ()
157191
end
158192
end
159193

194+
160195
function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ())
161196
neg_logden_indep_normal
162197
end
@@ -173,25 +208,28 @@ const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0]
173208
# const xP_S2 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
174209

175210
HVI.get_hybridproblem_n_covar(prob::DoubleMMCase; scenario) = 5
176-
function HVI.get_hybridproblem_n_site(prob::DoubleMMCase; scenario)
211+
function HVI.get_hybridproblem_n_site_and_batch(prob::DoubleMMCase; scenario)
212+
n_batch = 20
213+
n_site = 800
177214
if (:few_sites ∈ scenario)
178-
return(100)
215+
n_site = 100
179216
elseif (:sites20 ∈ scenario)
180-
return(20)
217+
n_site = 20
181218
end
182-
800
219+
(n_site, n_batch)
183220
end
184221

185222
function HVI.get_hybridproblem_train_dataloader(prob::DoubleMMCase; scenario = (),
186-
n_batch, rng::AbstractRNG = StableRNG(111), kwargs...
223+
rng::AbstractRNG = StableRNG(111), kwargs...
187224
)
225+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
188226
construct_dataloader_from_synthetic(rng, prob; scenario, n_batch, kwargs...)
189227
end
190228

191229
function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
192230
scenario = ())
193231
n_covar_pc = 2
194-
n_site = get_hybridproblem_n_site(prob; scenario)
232+
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
195233
n_covar = get_hybridproblem_n_covar(prob; scenario)
196234
n_θM = length(θM)
197235
FloatType = get_hybridproblem_float_type(prob; scenario)
@@ -201,7 +239,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
201239
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
202240
# normalize to be distributed around the prescribed true values
203241
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
204-
f = get_hybridproblem_PBmodel(prob; scenario, gdev=identity)
242+
f = get_hybridproblem_PBmodel(prob; scenario, gdev=identity, use_all_sites = true)
205243
xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site)
206244
θP = par_templates.θP
207245
y_global_true, y_true = f(θP, θMs_true, xP)

0 commit comments

Comments
 (0)