Skip to content

Commit 4105d3c

Browse files
committed
add method create_nsite_applicator to PBMApplicator
1 parent 0287718 commit 4105d3c

File tree

4 files changed

+93
-21
lines changed

4 files changed

+93
-21
lines changed

src/ComponentArrayInterpreter.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ function ComponentArrayInterpreter(
171171
ComponentArrayInterpreter(n_dims, CA.getaxes(ca), m_dims)
172172
end
173173

174+
174175
function ComponentArrayInterpreter(
175176
n_dims::NTuple{N,<:Integer}, axes::NTuple{A,<:CA.AbstractAxis},
176177
m_dims::NTuple{M,<:Integer}) where {N,A,M}

src/HybridVariationalInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ include("ModelApplicator.jl")
5555

5656
export AbstractPBMApplicator, NullPBMApplicator, PBMSiteApplicator, PBMPopulationApplicator
5757
export DirectPBMApplicator, PBMPopulationGlobalApplicator
58+
export create_nsite_applicator
5859
include("PBMApplicator.jl")
5960

6061
# export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler

src/PBMApplicator.jl

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ function (app::AbstractPBMApplicator)(θP::AbstractArray, θMs::AbstractArray, x
2828
apply_model(app, θP, θMs, xP)
2929
end
3030

31+
function create_nsite_applicator(app::AbstractPBMApplicator, n_site)
32+
copy(app)
33+
end
34+
35+
3136
"""
3237
apply_model(app::AbstractPBMApplicator, θsP::AbstractVector, θsMs::AbstractMatrix, xP::AbstractMatrix)
3338
apply_model(app::AbstractPBMApplicator, θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, xP)
@@ -114,7 +119,7 @@ struct PBMSiteApplicator{F, IT, IXT, VFT} <: AbstractPBMApplicator
114119
end
115120

116121
"""
117-
PBMSiteApplicator(fθ, n_batch; θP, θM, θFix, xPvec)
122+
PBMSiteApplicator(fθ; θP, θM, θFix, xPvec)
118123
119124
Construct AbstractPBMApplicator from process-based model `fθ` that computes predictions
120125
for a single site.
@@ -189,10 +194,10 @@ end
189194
@functor PBMPopulationApplicator (θFixm, rep_fac)
190195

191196
"""
192-
PBMPopulationApplicator(fθpop, n_batch; θP, θM, θFix, xPvec)
197+
PBMPopulationApplicator(fθpop, n_site; θP, θM, θFix, xPvec)
193198
194199
Construct AbstractPBMApplicator from process-based model `fθ` that computes predictions
195-
across sites for a population of size `n_batch`.
200+
across sites for a population of size `n_site`.
196201
The applicator combines enclosed `θFix`, with provided `θMs` and `θP`
197202
to a `ComponentMatrix` with parameters with one row for each site, that
198203
can be column-indexed by Symbols.
@@ -203,29 +208,41 @@ can be column-indexed by Symbols.
203208
- `θc`: parameters: `ComponentMatrix` (n_site x n_par) with each row a parameter vector
204209
- `xPc`: observations: `ComponentMatrix` (n_obs x n_site) with each column
205210
observationsfor one site
206-
- `n_batch`: number of indiduals, i.e. rows in `θMs`
211+
- `n_site`: number of indiduals, i.e. rows in `θMs`
207212
- `θP`: `ComponentVector` template of global process model parameters
208213
- `θM`: `ComponentVector` template of individual process model parameters
209214
- `θFix`: `ComponentVector` of actual fixed process model parameters
210215
- `xPvec`: `ComponentVector` template of model drivers for a single site
211216
"""
212-
function PBMPopulationApplicator(fθpop, n_batch;
217+
function PBMPopulationApplicator(fθpop, n_site;
213218
θP::CA.ComponentVector, θM::CA.ComponentVector, θFix::CA.ComponentVector,
214219
xPvec::CA.ComponentVector
215220
)
216221
intθvec = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
217222
int_xP_vec = ComponentArrayInterpreter(xPvec)
218-
isFix = repeat(axes(θFix, 1)', n_batch)
223+
isFix = repeat(axes(θFix, 1)', n_site)
219224
#
220-
intθ = get_concrete(ComponentArrayInterpreter((n_batch,), intθvec))
221-
int_xP = get_concrete(ComponentArrayInterpreter(int_xP_vec, (n_batch,)))
222-
#isP = repeat(axes(θP, 1)', n_batch)
225+
intθ = get_concrete(ComponentArrayInterpreter((n_site,), intθvec))
226+
int_xP = get_concrete(ComponentArrayInterpreter(int_xP_vec, (n_site,)))
227+
#isP = repeat(axes(θP, 1)', n_site)
223228
# n_site = size(θMs, 1)
224-
rep_fac = ones_similar_x(θP, n_batch) # to reshape into matrix, avoiding repeat
229+
rep_fac = ones_similar_x(θP, n_site) # to reshape into matrix, avoiding repeat
225230
θFixm = CA.ComponentMatrix(θFix[isFix], (CA.FlatAxis(), CA.getaxes(θFix)[1]))
226231
PBMPopulationApplicator(fθpop, θFixm, rep_fac, intθ, int_xP)
227232
end
228233

234+
function create_nsite_applicator(app::PBMPopulationApplicator, n_site)
235+
θFix = app.θFixm[1,:]
236+
isFix = repeat(axes(θFix, 1)', n_site)
237+
θFixm = CA.ComponentMatrix(θFix[isFix], (CA.FlatAxis(), CA.getaxes(θFix)[1]))
238+
#
239+
intθ = get_concrete(ComponentArrayInterpreter((n_site,), (CA.getaxes(app.intθ)[2],),()))
240+
int_xP = get_concrete(ComponentArrayInterpreter(
241+
(), (CA.getaxes(app.int_xP)[1],), (n_site,)))
242+
rep_fac = ones_similar_x(θFix, n_site) # to reshape into matrix, avoiding repeat
243+
PBMPopulationApplicator(app.fθpop, θFixm, rep_fac, intθ, int_xP)
244+
end
245+
229246
function apply_model(app::PBMPopulationApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP)
230247
if (CA.getdata(θP) isa GPUArraysCore.AbstractGPUArray) &&
231248
(!(CA.getdata(app.θFixm) isa GPUArraysCore.AbstractGPUArray) ||
@@ -262,10 +279,10 @@ end
262279

263280

264281
"""
265-
PBMPopulationGlobalApplicator(fθpop, n_batch; θP, θM, θFix, xPvec)
282+
PBMPopulationGlobalApplicator(fθpop, n_site; θP, θM, θFix, xPvec)
266283
267284
Construct AbstractPBMApplicator from process-based model `fθ` that computes predictions
268-
across sites for a population of size `n_batch`.
285+
across sites for a population of size `n_site`.
269286
The applicator combines enclosed `θFix`, with provided `θMs` and `θP`
270287
to a `ComponentMatrix` with parameters with one row for each site, that
271288
can be column-indexed by Symbols.
@@ -278,24 +295,33 @@ can be column-indexed by Symbols.
278295
- `θgc`: parameters: `ComponentVector` (n_par_global)
279296
- `xPc`: observations: `ComponentMatrix` (n_obs x n_site) with each column
280297
observationsfor one site
281-
- `n_batch`: number of indiduals, i.e. rows in `θMs`
298+
- `n_site`: number of indiduals, i.e. rows in `θMs`
282299
- `θP`: `ComponentVector` template of global process model parameters
283300
- `θM`: `ComponentVector` template of individual process model parameters
284301
- `θFix`: `ComponentVector` of actual fixed process model parameters
285302
- `xPvec`: `ComponentVector` template of model drivers for a single site
286303
"""
287-
function PBMPopulationGlobalApplicator(fθpop, n_batch;
304+
function PBMPopulationGlobalApplicator(fθpop, n_site;
288305
θP::CA.ComponentVector, θM::CA.ComponentVector, θFix::CA.ComponentVector,
289306
xPvec::CA.ComponentVector
290307
)
291-
intθvec = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
308+
#intθvec = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
292309
int_xP_vec = ComponentArrayInterpreter(xPvec)
293-
intθs = get_concrete(ComponentArrayInterpreter((n_batch,), θM))
310+
intθs = get_concrete(ComponentArrayInterpreter((n_site,), θM))
294311
intθg = get_concrete(ComponentArrayInterpreter(vcat(θP, θFix)))
295-
int_xP = get_concrete(ComponentArrayInterpreter(int_xP_vec, (n_batch,)))
312+
int_xP = get_concrete(ComponentArrayInterpreter(int_xP_vec, (n_site,)))
296313
PBMPopulationGlobalApplicator(fθpop, θFix, intθs, intθg, int_xP)
297314
end
298315

316+
function create_nsite_applicator(app::PBMPopulationGlobalApplicator, n_site)
317+
@info("called PBMPopulationGlobalApplicator.create_nsite_applicator")
318+
intθs = get_concrete(ComponentArrayInterpreter((n_site,), (CA.getaxes(app.intθs)[2],),()))
319+
int_xP = get_concrete(ComponentArrayInterpreter(
320+
(), (CA.getaxes(app.int_xP)[1],), (n_site,)))
321+
PBMPopulationGlobalApplicator(app.fθpop, app.θFix, intθs, app.intθg, int_xP)
322+
end
323+
324+
299325
function apply_model(app::PBMPopulationGlobalApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP)
300326
if (CA.getdata(θP) isa GPUArraysCore.AbstractGPUArray) &&
301327
(!(CA.getdata(app.θFix) isa GPUArraysCore.AbstractGPUArray) ||

test/test_PBMApplicator.jl

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,49 @@ import Zygote
66

77
using MLDataDevices, CUDA, cuDNN, GPUArraysCore
88

9+
f_pop = function(θsc, xPc)
10+
local n_obs = size(xPc, 1)
11+
is_valid = isfinite.(CA.getdata(xPc))
12+
a1 = is_valid .* CA.getdata(θsc[:,:a1])'
13+
a2 = is_valid .* CA.getdata(θsc[:,:a2])'
14+
# b in θP has been expanded in PopulationApplicator
15+
b = is_valid .* CA.getdata(θsc[:,:b])'
16+
y = a1 .+ log.(a2) .* abs2.(cos.(b .- 0.2)) .* xPc.^2
17+
end
18+
19+
() -> begin
20+
include("test/test_scratch.jl")
21+
end
22+
23+
@testset "PBMPopulationApplicator" begin
24+
n_obs = 3
25+
n_site = 5
26+
xPvec = CA.ComponentVector(s1 = 1.0:n_obs)
27+
xPc = xPvec .* ones(n_site)' .+ abs2.(randn(n_obs, n_site) .* 0.1)
28+
θP = CA.ComponentVector(b=3.0)
29+
θM = CA.ComponentVector(a1=2.0,a2=1.0)
30+
θFix = CA.ComponentVector(c=1.5)
31+
#
32+
θMs = (ones(n_site) .* θM') .+ abs2.(randn(n_site, length(θM)) .* 0.1)
33+
θs = hcat(ones(n_site) .* θP', θMs)
34+
y_obs = f_pop(θs, xPc)
35+
g = PBMPopulationApplicator(f_pop, n_site; θP, θM, θFix, xPvec)
36+
ret = g(θP, θMs, xPc)
37+
@test ret y_obs
38+
39+
gr = Zygote.gradient((θP, θMs) -> sum(g(θP, θMs, xPc)), CA.getdata(θP), CA.getdata(θMs))
40+
41+
xPc_NaN = copy(xPc); xPc_NaN[2:3,1] .= NaN
42+
ret2 = g(θP, θMs, xPc_NaN)
43+
gr = Zygote.gradient((θP, θMs) -> sum(g(θP, θMs, xPc_NaN)), CA.getdata(θP), CA.getdata(θMs))
44+
@test all(isfinite.(gr[1]))
45+
46+
n_site_m = 4
47+
gm = create_nsite_applicator(g, n_site_m)
48+
retm = gm(θP, θMs[1:n_site_m,:], xPc[:,1:n_site_m])
49+
@test retm y_obs[:,1:n_site_m]
50+
end;
51+
952
f_global = function(θsc, θgc, xPc)
1053
local n_obs = size(xPc, 1)
1154
#is_dummy = isnan.(CA.getdata(xPc))
@@ -21,10 +64,6 @@ f_global = function(θsc, θgc, xPc)
2164
y = a1 .+ log.(a2) .* abs2.(cos.(b .- 0.2)) .* xPc.^2
2265
end
2366

24-
() -> begin
25-
include("test/test_scratch.jl")
26-
end
27-
2867
@testset "PBMPopulationGlobalApplicator" begin
2968
n_obs = 3
3069
n_site = 5
@@ -47,5 +86,10 @@ end
4786
gr = Zygote.gradient((θP, θMs) -> sum(g(θP, θMs, xPc_NaN)), CA.getdata(θP), CA.getdata(θMs))
4887
@test all(isfinite.(gr[1])) # \thetaP
4988
@test all(isfinite.(gr[2])) # solves finite gradient for a2 for first site
89+
90+
n_site_m = 4
91+
gm = create_nsite_applicator(g, n_site_m)
92+
retm = gm(θP, θMs[1:n_site_m,:], xPc[:,1:n_site_m])
93+
@test retm y_obs[:,1:n_site_m]
5094
end;
5195

0 commit comments

Comments
 (0)