Skip to content

Commit 36851f1

Browse files
committed
move PBMApplicator to gdev_P int solve and predict_hvi
replace stack in apply_process_model by mapreduce
1 parent 8492874 commit 36851f1

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

src/HybridSolver.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
134134
train_loader_dev = train_loader
135135
end
136136
f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=false)
137+
if gdevs.gdev_P isa MLDataDevices.AbstractGPUDevice
138+
f_dev = fmap(gdevs.gdev_P, f)
139+
else
140+
f_dev = f
141+
end
142+
137143
py = get_hybridproblem_neg_logden_obs(prob; scenario)
138144

139145
priors_θP_mean, priors_θMs_mean = construct_priors_θ_mean(
@@ -142,7 +148,7 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
142148
y_global_o = Float32[] # TODO
143149

144150
loss_elbo = get_loss_elbo(
145-
g_dev, transP, transMs, f, py, y_global_o;
151+
g_dev, transP, transMs, f_dev, py, y_global_o;
146152
solver.n_MC, solver.n_MC_cap, cor_ends, priors_θP_mean, priors_θMs_mean,
147153
cdev=infer_cdev(gdevs), pbm_covars, θP, int_unc, int_μP_ϕg_unc)
148154
# test loss function once

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ include("gencovar.jl")
8484
export callback_loss
8585
include("util_opt.jl")
8686

87-
export cpu_ca
87+
export cpu_ca, apply_preserve_axes
8888
include("util_ca.jl")
8989

9090
export neg_logden_indep_normal, entropy_MvNormal

src/elbo.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,12 @@ function predict_hvi(rng, prob::AbstractHybridProblem; scenario=Val(()),
172172
is_predict_batch = (n_site_pred == n_batch)
173173
@assert size(xP, 2) == n_site_pred
174174
f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=!is_predict_batch)
175-
y = apply_process_model(θsP, θsMs, f, xP)
175+
if gdevs.gdev_P isa MLDataDevices.AbstractGPUDevice
176+
f_dev = fmap(gdevs.gdev_P, f)
177+
else
178+
f_dev = f
179+
end
180+
y = apply_process_model(θsP, θsMs, f_dev, xP)
176181
(; y, θsP, θsMs, entropy_ζ)
177182
end
178183

@@ -317,10 +322,20 @@ Call a PBM applicator for a sample of parameters of each site, and stack results
317322
Results are of shape `(n_obs x n_site_pred x n_MC)`.
318323
"""
319324
function apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET
320-
y_pred = stack(map(eachcol(θsP), eachslice(θsMs, dims=3)) do θP, θMs
325+
# stack does not work on GPU
326+
# y_pred = stack(map(eachcol(θsP), eachslice(θsMs, dims=3)) do θP, θMs
327+
# y_global, y_pred_i = f(θP, θMs, xP)
328+
# y_pred_i
329+
# end)
330+
#Main.@infiltrate_main
331+
# for type stability, apply f at first sample before mapreduce
332+
P1, Pit = Iterators.peel(eachcol(θsP));
333+
Ms1, Msit = Iterators.peel(eachslice(θsMs, dims=3));
334+
y1 = f(P1, Ms1, xP)[2]
335+
y_pred = mapreduce((a,b) -> cat(a,b;dims=3), Pit, Msit; init=y1) do θP, θMs
321336
y_global, y_pred_i = f(θP, θMs, xP)
322337
y_pred_i
323-
end)
338+
end
324339
end
325340

326341
"""

src/util_ca.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
44
Move ComponentArray form gpu to cpu.
55
"""
6-
#function cpu_ca end
7-
# define in FluxExt
86
function cpu_ca(ca::CA.ComponentArray)
97
CA.ComponentArray(cpu_device()(CA.getdata(ca)), CA.getaxes(ca))
108
end
119

10+
"""
11+
apply_preserve_axes(f, ca::ComponentArray)
1212
13+
Apply callable `f(x)` to the data inside `ca`, assume that the result has
14+
the same shape, and return a new `ComponentArray` with the same axes
15+
as in `ca`.
16+
"""
1317
function apply_preserve_axes(f, ca::CA.ComponentArray)
1418
CA.ComponentArray(f(CA.getdata(ca)), CA.getaxes(ca))
1519
end

0 commit comments

Comments
 (0)