Skip to content

Commit 5996128

Browse files
committed
move apply_process_model to PBMApplicator.apply_model
keep stack in non-GPU call, because mapreduce is very slow (despite then its not compatible with Zygor) for GPUArrays use mapreduce, because stack results in scalar indexing error
1 parent 36851f1 commit 5996128

File tree

6 files changed

+107
-52
lines changed

6 files changed

+107
-52
lines changed

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true)
572572
trans_mP=StackedArray(transP, size(ζsP, 2))
573573
trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3))
574574
θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs)
575-
y = apply_process_model(θsP, θsMs, f, xP)
575+
y = f(θsP, θsMs, f, xP)
576576
#(; y, θsP, θsMs) = HVI.apply_f_trans(ζsP, ζsMs, f_allsites, xP; transP, transM);
577577
(y_hmc, θsP_hmc, θsMs_hmc) = (; y, θsP, θsMs);
578578

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ include("logden_normal.jl")
9393
export get_ca_starts, get_ca_ends, get_cor_count
9494
include("cholesky.jl")
9595

96-
export neg_elbo_gtf, sample_posterior, apply_process_model, predict_hvi
96+
export neg_elbo_gtf, sample_posterior, predict_hvi
9797
include("elbo.jl")
9898

9999
export init_hybrid_params, init_hybrid_ϕunc

src/PBMApplicator.jl

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,67 @@
11
"""
2-
AbstractPBMApplicator(θP::AbstractVector, θMs::AbstractMatrix, xP::AbstractMatrix)
3-
42
Abstraction of applying a process-based model with
5-
global parameters, `x`, site-specific parameters, `θMs` (sites in columns),
3+
global parameters, `θP`, site-specific parameters, `θMs` (sites in columns),
64
and site-specific model drivers, `xP` (sites in columns),
75
It returns a matrix of predictions sites in columns.
86
9-
Specific implementations need to implement function `apply_model(app, θP, θMs, xP)`.
7+
Specific implementations need to provide function `apply_model(app, θP, θMs, xP)`.
8+
where
9+
- `θsP` and `θsMs` are shaped according to the output of `generate_ζ`, i.e.
10+
`(n_site_pred x n_par x n_MC)`.
11+
- Results are of shape `(n_obs x n_site_pred x n_MC)`.
12+
13+
They may also provide function `apply_model(app, θP, θMs, xP)` for a sample
14+
of parameters, i.e. where an additional dimension is added to both `θP` and `θMs`.
15+
However, there is a default implementation that mapreduces across these dimensions.
16+
1017
Provided are implementations
11-
- `NullPBMApplicator`: returning its input `θMs` for testing
1218
- `PBMSiteApplicator`: based on a function that computes predictions per site
1319
- `PBMPopulationApplicator`: based on a function that computes predictions for entire population
20+
- `NullPBMApplicator`: returning its input `θMs` for testing
21+
- `PlainPBMApplicator`: based on a function that takes the same arguments as `apply_model`
1422
"""
1523
abstract type AbstractPBMApplicator end
1624

1725
# function apply_model end # already defined in ModelApplicator.jl for ML model
1826

19-
function (app::AbstractPBMApplicator)(θP::AbstractVector, θMs::AbstractMatrix, xP::AbstractMatrix)
27+
function (app::AbstractPBMApplicator)(θP::AbstractArray, θMs::AbstractArray, xP::AbstractMatrix)
2028
apply_model(app, θP, θMs, xP)
2129
end
2230

31+
"""
32+
apply_model(app::AbstractPBMApplicator, θsP::AbstractVector, θsMs::AbstractMatrix, xP::AbstractMatrix)
33+
apply_model(app::AbstractPBMApplicator, θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, xP)
34+
35+
The first variant calls the PBM for one batch of sites.
36+
37+
The second variant calls the PBM for a sample of batches, and stack results.
38+
The default implementation mapreduces the last dimension of `θsP` and θ`sMs` calling the
39+
first variant of `apply_model` for each sample.
40+
"""
41+
# docu in struct
42+
function apply_model(app::AbstractPBMApplicator, θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, xP) where ET
43+
# stack does not work on GPU, see specialized method for GPUArrays below
44+
y_pred = stack(
45+
map(eachcol(CA.getdata(θsP)), eachslice(CA.getdata(θsMs), dims=3)) do θP, θMs
46+
y_global, y_pred_i = app(θP, θMs, xP)
47+
y_pred_i
48+
end)
49+
end
50+
function apply_model(app::AbstractPBMApplicator, θsP::GPUArraysCore.AbstractGPUMatrix, θsMs::GPUArraysCore.AbstractGPUArray{ET,3}, xP) where ET
51+
# stack does not work on GPU, need to resort to slower mapreduce
52+
# for type stability, apply f at first iterate to supply init to mapreduce
53+
P1, Pit = Iterators.peel(eachcol(CA.getdata(θsP)));
54+
Ms1, Msit = Iterators.peel(eachslice(CA.getdata(θsMs), dims=3));
55+
y1 = apply_model(app, P1, Ms1, xP)[2]
56+
y1a = reshape(y1, size(y1)..., 1) # add one dimension
57+
y_pred = mapreduce((a,b) -> cat(a,b; dims=3), Pit, Msit; init=y1a) do θP, θMs
58+
y_global, y_pred_i = app(θP, θMs, xP)
59+
y_pred_i
60+
end
61+
end
62+
63+
64+
2365

2466
"""
2567
NullPBMApplicator()
@@ -119,8 +161,8 @@ struct PBMPopulationApplicator{MFT, IPT, IT, IXT, F} <: AbstractPBMApplicator
119161
int_xP::IXT
120162
end
121163

122-
# let fmap not descend into isP
123-
# @functor PBMPopulationApplicator (θFixm, )
164+
# let fmap not descend into isP, because indexing with isP on cpu is faster
165+
@functor PBMPopulationApplicator (θFixm, )
124166

125167
"""
126168
PBMPopulationApplicator(fθpop, n_batch; θP, θM, θFix, xPvec)
@@ -167,7 +209,13 @@ function apply_model(app::PBMPopulationApplicator, θP::AbstractVector, θMs::Ab
167209
"or compute PBM on CPU.")
168210
end
169211
# repeat θP and concatenate with
212+
# Main.@infiltrate_main
213+
# repeat is 2x slower for Vector and 100 times slower (with allocation) on GPU
214+
# app.isP on CPU is slightly faster than app.isP on GPU
215+
#@benchmark CA.getdata(θP[app.isP])
216+
#@benchmark CA.getdata(repeat(θP', size(θMs,1)))
170217
local θ = hcat(CA.getdata(θP[app.isP]), CA.getdata(θMs), app.θFixm)
218+
#local θ = hcat(CA.getdata(repeat(θP', size(θMs,1))), CA.getdata(θMs), app.θFixm)
171219
local θc = app.intθ(CA.getdata(θ))
172220
local xPc = app.int_xP(CA.getdata(xP))
173221
local pred_sites = app.fθpop(θc, xPc)

src/elbo.jl

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ function predict_hvi(rng, prob::AbstractHybridProblem; scenario=Val(()),
177177
else
178178
f_dev = f
179179
end
180-
y = apply_process_model(θsP, θsMs, f_dev, xP)
180+
#y = apply_process_model(θsP, θsMs, f_dev, xP)
181+
y = f_dev(θsP, θsMs, xP)
181182
(; y, θsP, θsMs, entropy_ζ)
182183
end
183184

@@ -312,31 +313,33 @@ end
312313
# (; y, θP, θMs)
313314
# end
314315

315-
"""
316-
apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP)
316+
# """
317+
# apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP)
317318

318-
Call a PBM applicator for a sample of parameters of each site, and stack results
319+
# Call a PBM applicator for a sample of parameters of each site, and stack results
319320

320-
`θsP` and `θsMs` are shaped according to the output of `generate_ζ`, i.e.
321-
`(n_site_pred x n_par x n_MC)`.
322-
Results are of shape `(n_obs x n_site_pred x n_MC)`.
323-
"""
324-
function apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET
325-
# stack does not work on GPU
326-
# y_pred = stack(map(eachcol(θsP), eachslice(θsMs, dims=3)) do θP, θMs
327-
# y_global, y_pred_i = f(θP, θMs, xP)
328-
# y_pred_i
329-
# end)
330-
#Main.@infiltrate_main
331-
# for type stability, apply f at first sample before mapreduce
332-
P1, Pit = Iterators.peel(eachcol(θsP));
333-
Ms1, Msit = Iterators.peel(eachslice(θsMs, dims=3));
334-
y1 = f(P1, Ms1, xP)[2]
335-
y_pred = mapreduce((a,b) -> cat(a,b;dims=3), Pit, Msit; init=y1) do θP, θMs
336-
y_global, y_pred_i = f(θP, θMs, xP)
337-
y_pred_i
338-
end
339-
end
321+
# `θsP` and `θsMs` are shaped according to the output of `generate_ζ`, i.e.
322+
# `(n_site_pred x n_par x n_MC)`.
323+
# Results are of shape `(n_obs x n_site_pred x n_MC)`.
324+
# """
325+
# function apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET
326+
# error("deprecated, use f(θsP, θsMs, xP)")
327+
# # stack does not work on GPU
328+
# # y_pred = stack(
329+
# # map(eachcol(CA.getdata(θsP)), eachslice(CA.getdata(θsMs), dims=3)) do θP, θMs
330+
# # y_global, y_pred_i = f(θP, θMs, xP)
331+
# # y_pred_i
332+
# # end)
333+
# # for type stability, apply f at first iterate to supply init to mapreduce
334+
# P1, Pit = Iterators.peel(eachcol(CA.getdata(θsP)));
335+
# Ms1, Msit = Iterators.peel(eachslice(CA.getdata(θsMs), dims=3));
336+
# y1 = f(P1, Ms1, xP)[2]
337+
# y1a = reshape(y1, size(y1)..., 1) # add one dimension
338+
# y_pred = mapreduce((a,b) -> cat(a,b; dims=3), Pit, Msit; init=y1a) do θP, θMs
339+
# y_global, y_pred_i = f(θP, θMs, xP)
340+
# y_pred_i
341+
# end
342+
# end
340343

341344
"""
342345
Generate samples of (inv-transformed) model parameters, ζ,

test/test_HybridProblem.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,11 @@ test_with_flux_gpu = (scenario) -> begin
316316
rng = StableRNG(111)
317317
probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf);
318318
# put Applicator to gpu (θFix)
319-
probg = HybridProblem(
320-
probg,
321-
f_batch = fmap(gdev, probg.f_batch),
322-
f_allsites = fmap(gdev, probg.f_allsites))
319+
# moved to solve and predict_hvi
320+
# probg = HybridProblem(
321+
# probg,
322+
# f_batch = fmap(gdev, probg.f_batch),
323+
# f_allsites = fmap(gdev, probg.f_allsites))
323324
#prob = CP.update(probg, transM = identity, transP = identity);
324325
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
325326
n_site, n_batch = get_hybridproblem_n_site_and_batch(probg; scenario = scenf)

test/test_elbo.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ using MLDataDevices
1818
# setup g as FluxNN on gpu
1919
using Flux
2020

21+
#CUDA.device!(4)
22+
2123
ggdev = gpu_device()
2224

23-
#CUDA.device!(4)
2425
rng = StableRNG(111)
2526

2627
const prob = DoubleMM.DoubleMMCase()
@@ -143,7 +144,7 @@ test_scenario = (scenario) -> begin
143144
= vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc)
144145
#hcat(ϕ_ini, ϕ, _ϕ)[1:4,:]
145146
#hcat(ϕ_ini, ϕ, _ϕ)[(end-20):end,:]
146-
n_predict = 80000
147+
n_predict = 8000
147148
xM_batch = xM[:, 1:n_batch]
148149
_ζsP, _ζsMs, _σ = @inferred (
149150
# @descend_code_warntype (
@@ -196,18 +197,18 @@ test_scenario = (scenario) -> begin
196197
reshape(residMst, size(residMst,1)*size(residMst,2), size(residMst,3)))
197198
cor_PMs = cor(residPMst')
198199
@test cor_PMs[1,2] ρsP_true[1] atol=0.02
199-
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.02)) # no correlations P,M
200+
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.1)) # no correlations P,M
200201
@test cor_PMs[3,4] ρsM_true[1] atol=0.02
201-
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.02)) # no correlations M1, M2
202+
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.1)) # no correlations M1, M2
202203
@test cor_PMs[5,6] ρsM_true[1] atol=0.02
203-
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.02)) # no correlations M1, M2
204+
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.1)) # no correlations M1, M2
204205
end
205206
test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g)
206207
@testset "predict_hvi check sd" begin
207208
# test if uncertainty and reshaping is propagated
208209
# here inverse the predicted θs and then test distribution
209210
probcu = HybridProblem(probc, ϕunc=ϕunc_true);
210-
n_sample_pred = 24_000
211+
n_sample_pred = 2_400
211212
(; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probcu; scenario, n_sample_pred);
212213
#size(_ζsMs), size(θsMs)
213214
#size(_ζsP), size(θsP)
@@ -221,7 +222,7 @@ test_scenario = (scenario) -> begin
221222
test_distζ(_ζsP2, _ζsMs2, ϕunc_true, ζMs_g2)
222223
end;
223224
end;
224-
end # if covar in scenario
225+
end # if covarK2 in scenario
225226

226227
if ggdev isa MLDataDevices.AbstractGPUDevice
227228
@testset "generate_ζ gpu $(last(CP._val_value(scenario)))" begin
@@ -390,14 +391,13 @@ test_scenario = (scenario) -> begin
390391
θsPc = int_mP(θsP)
391392
@test all(θsPc[:r0, :] .> 0)
392393
#
393-
y = apply_process_model(θsP, θsMs, f_pred, xP)
394+
y = @inferred f_pred(θsP, θsMs, xP)
394395
@test y isa Array
395396
@test size(y) == (size(y_o)..., n_sample_pred)
396397
end
397398

398399
if ggdev isa MLDataDevices.AbstractGPUDevice
399400
@testset "predict_hvi gpu $(last(CP._val_value(scenario)))" begin
400-
n_sample_pred = 32
401401
ϕ_ini_g = ggdev(CA.getdata(ϕ_ini))
402402
xMg = ggdev(xM)
403403
n_sample_pred = 30
@@ -407,17 +407,20 @@ test_scenario = (scenario) -> begin
407407
sample_posterior(rng, g_gpu, ϕ_ini_g, xMg;
408408
int_μP_ϕg_unc, int_unc,
409409
transP, transM,
410-
cdev = cpu_device(),
410+
#cdev = cpu_device(),
411+
cdev = identity, # do not transfer to CPU
411412
n_sample_pred, cor_ends, pbm_covar_indices)
412413
)
414+
# this variant without the problem, does not attach axes
413415
@test θsP isa AbstractMatrix
414416
@test θsMs isa AbstractArray{T,3} where {T}
415417
int_mP = ComponentArrayInterpreter(int_P, (size(θsP, 2),))
416-
θsPc = int_mP(θsP)
417-
@test all(θsPc[:r0, :] .> 0)
418+
@test all(int_mP(θsP)[:r0, :] .> 0)
418419
#
419-
y = apply_process_model(θsP, θsMs, f_pred, xP)
420-
@test y isa Array
420+
xP_dev = ggdev(xP);
421+
f_pred_dev = fmap(ggdev, f_pred)
422+
y = @inferred f_pred_dev(θsP, θsMs, xP_dev)
423+
@test y isa GPUArraysCore.AbstractGPUArray
421424
@test size(y) == (size(y_o)..., n_sample_pred)
422425
end
423426
# @testset "predict_hvi also f on gpu" begin

0 commit comments

Comments
 (0)