Skip to content

Commit f4ef19d

Browse files
committed
provide xP (PBM drivers) as a Matrix with dataloader
to simplify recreating the tuples of matrices
1 parent c931205 commit f4ef19d

File tree

7 files changed

+56
-35
lines changed

7 files changed

+56
-35
lines changed

src/AbstractHybridProblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ function construct_dataloader_from_synthetic(rng::AbstractRNG, prob::AbstractHyb
179179
)
180180
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario)
181181
n_site = size(xM,2)
182-
@assert length(xP) == n_site
182+
@assert size(xP,2) == n_site
183183
@assert size(y_o,2) == n_site
184184
@assert size(y_unc,2) == n_site
185185
i_sites = 1:n_site

src/DoubleMM/f_doubleMM.jl

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ const θall = vcat(θP, θM)
66

77
const θP_nor0 = θP[(:K2,)]
88

9+
const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1]
10+
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0]
11+
12+
int_xP1 = ComponentArrayInterpreter(CA.ComponentVector(S1=xP_S1, S2=xP_S2))
13+
914
# const transP = elementwise(exp)
1015
# const transM = elementwise(exp)
1116

@@ -164,20 +169,29 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = ()
164169
intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
165170
θFix = repeat(θFix1', n_site_batch)
166171
intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1))
172+
#int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
167173
isP = repeat(axes(par_templates.θP,1)', n_site_batch)
168-
let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP=isP, n_site_batch=n_site_batch
174+
let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP=isP,
175+
n_site_batch=n_site_batch,
176+
#int_xPb=get_concrete(int_xPb),
177+
pos_xP = get_positions(int_xP1)
169178
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
170-
@assert length(xP) == n_site_batch
179+
@assert size(xP,2) == n_site_batch
171180
@assert size(θMs,2) == n_site_batch
172-
# convert vector of tuples to tuple of matricesByRows
173-
# need to supply xP as vectorOfTuples to work with DataLoader
174-
# k = first(keys(xP[1]))
175-
xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
176-
#stack(map(r -> r[k], xP))'
177-
stack(map(r -> r[k], xP); dims = 1)
178-
end)...)
181+
# # convert vector of tuples to tuple of matricesByRows
182+
# # need to supply xP as vectorOfTuples to work with DataLoader
183+
# # k = first(keys(xP[1]))
184+
# xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
185+
# #stack(map(r -> r[k], xP))'
186+
# stack(map(r -> r[k], xP); dims = 1)
187+
# end)...)
179188
#xPM = map(transpose, xPM1)
189+
#xPc = int_xPb(CA.getdata(xP))
190+
#xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
180191
# make sure the same order of columns as in intθ
192+
# reshape big matrix into NamedTuple of drivers S1 and S2
193+
# for broadcasting need sites in rows
194+
xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
181195
θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? θFix_dev : θFix
182196
θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs)', θFixd)
183197
pred_sites = f_doubleMM(θ, xPM, intθ)'
@@ -202,8 +216,6 @@ end
202216
# return Float32
203217
# end
204218

205-
const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1]
206-
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0]
207219

208220
# two observations more?
209221
# const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.1]
@@ -242,7 +254,10 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
242254
# normalize to be distributed around the prescribed true values
243255
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
244256
f = get_hybridproblem_PBmodel(prob; scenario, gdev=identity, use_all_sites = true)
245-
xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site)
257+
#xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site)
258+
int_xPn = ComponentArrayInterpreter(int_xP1, (n_site,))
259+
xP = int_xPn(vcat(repeat(xP_S1,1,n_site),repeat(xP_S2,1,n_site)))
260+
#xP[:S1,:]
246261
θP = par_templates.θP
247262
y_global_true, y_true = f(θP, θMs_true, xP)
248263
σ_o = FloatType(0.01)

src/gf.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP, args...; kwargs...)
22
# predict several sites with same global parameters θP and fixed parameters θFix
3+
#θM, x_site = first(zip(eachcol(θMs), xP))
34
yv = map(eachcol(θMs), xP) do θM, x_site
45
f(vcat(θP, θM, θFix), x_site, args...; kwargs...)
56
end

src/logden_normal.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ function neg_logden_indep_normal(obs::AbstractArray, μ::AbstractArray, logσ2::
2424
# optimize argument logσ2 rather than σs for performance
2525
#nlogL = sum(σfac .* (1/2) .* logσ2 .+ (1/2) .* exp.(.- logσ2) .* abs2.(obs .- μ))
2626
# specifying logσ2 instead of σ is not transforming a random variable -> no Jacobian
27-
nlogL = sum(σfac .* logσ2 .+ abs2.(obs .- μ) .* exp.(.-logσ2)) / 2
27+
obs_data = CA.getdata(obs)
28+
μ_data = CA.getdata(μ)
29+
nlogL = sum(σfac .* logσ2 .+ abs2.(obs_data .- μ_data) .* exp.(.-logσ2)) / convert(eltype(μ),2)
2830
return (nlogL)
2931
end
3032
# function neg_logden_indep_normal(obss::AbstractMatrix, preds::AbstractMatrix, logσ2::AbstractVector; kwargs...)

test/test_HybridProblem.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ construct_problem = (;scenario=(:default,)) -> begin
3939
y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
4040
return (y)
4141
end
42-
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
43-
pred_sites = applyf(f_doubleMM, θMs, θP, CA.ComponentVector{FT}(), x)
42+
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
43+
#Main.@infiltrate_main
44+
#first(eachcol(xP))
45+
pred_sites = applyf(f_doubleMM, θMs, θP, CA.ComponentVector{FT}(), eachcol(xP))
4446
pred_global = eltype(pred_sites)[]
4547
return pred_global, pred_sites
4648
end
@@ -93,7 +95,7 @@ test_without_flux = (scenario) -> begin
9395
prob = probc = construct_problem(;scenario);
9496
#@descend construct_problem(;scenario)
9597

96-
@testset "n_input and pbm_covars" begin
98+
@testset "n_input and pbm_covars $(last(scenario))" begin
9799
g, ϕ_g = get_hybridproblem_MLapplicator(prob; scenario);
98100
if :covarK2 scenario
99101
@test g.app.m.inputdim == (static(6),) # 5 + 1 (ncovar + n_pbm)
@@ -104,7 +106,7 @@ test_without_flux = (scenario) -> begin
104106
end
105107
end
106108

107-
@testset "loss_gf" begin
109+
@testset "loss_gf $(last(scenario))" begin
108110
#----------- fit g and θP to y_o
109111
rng = StableRNG(111)
110112
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario)
@@ -157,7 +159,7 @@ gdev = gpu_device()
157159
test_with_flux = (scenario) -> begin
158160
prob = probc = construct_problem(;scenario);
159161

160-
@testset "HybridPointSolver" begin
162+
@testset "HybridPointSolver $(last(scenario))" begin
161163
rng = StableRNG(111)
162164
solver = HybridPointSolver(; alg=Adam(0.02))
163165
(; ϕ, resopt, probo) = solve(prob, solver; scenario, rng,
@@ -177,7 +179,7 @@ test_with_flux = (scenario) -> begin
177179
@test ϕ.ϕP.K2 < 1.5 * log(θP.K2)
178180
end;
179181

180-
@testset "HybridPosteriorSolver" begin
182+
@testset "HybridPosteriorSolver $(last(scenario))" begin
181183
rng = StableRNG(111)
182184
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
183185
(; ϕ, θP, resopt) = solve(prob, solver; scenario, rng,
@@ -195,7 +197,7 @@ test_with_flux = (scenario) -> begin
195197
end;
196198

197199
if gdev isa MLDataDevices.AbstractGPUDevice
198-
@testset "HybridPosteriorSolver gpu" begin
200+
@testset "HybridPosteriorSolver gpu $(last(scenario))" begin
199201
scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0)
200202
rng = StableRNG(111)
201203
# here using DoubleMMCase() directly rather than construct_problem
@@ -239,7 +241,7 @@ test_with_flux = (scenario) -> begin
239241
end
240242

241243
end;
242-
@testset "HybridPosteriorSolver also f on gpu" begin
244+
@testset "HybridPosteriorSolver also f on gpu $(last(scenario))" begin
243245
scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu)
244246
rng = StableRNG(111)
245247
probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf);

test/test_doubleMM.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fneglogden = get_hybridproblem_neg_logden_obs(prob; scenario)
4747
vec(mean(CA.getdata(θMs_true); dims = 2)), CA.getdata(par_templates.θM), rtol = 0.02)
4848
@test isapprox(vec(std(CA.getdata(θMs_true); dims = 2)),
4949
CA.getdata(par_templates.θM) .* 0.1, rtol = 0.02)
50-
@test size(xP) == (n_site,)
50+
@test size(xP) == (16, n_site)
5151
@test size(y_o) == (8, n_site)
5252

5353
# test same results for same rng
@@ -57,9 +57,10 @@ fneglogden = get_hybridproblem_neg_logden_obs(prob; scenario)
5757
end
5858

5959
@testset "f_doubleMM_Matrix" begin
60-
is = repeat(axes(θP_true, 1)', n_site)
60+
is = repeat((1:length(θP_true))', n_site)
6161
θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true)
62-
xPM = map(xP1s -> repeat(xP1s', n_site), xP[1])
62+
#xPM = map(xP1s -> repeat(xP1s', n_site), xP[1])
63+
xPM = (S1 = CA.getdata(xP[:S1,:])', S2 = CA.getdata(xP[:S2,:])')
6364
#θ = hcat(θP_true[is], θMs_true')
6465
intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1])))
6566
#θpos = get_positions(intθ1)
@@ -71,18 +72,18 @@ end
7172
end
7273
y = fy(θvec, xPM)
7374
y_exp = applyf(HVI.DoubleMM.f_doubleMM, θMs_true, θP_true,
74-
Vector{eltype(θP_true)}(undef, 0), xP, intθ1)
75+
Vector{eltype(θP_true)}(undef, 0), eachcol(xP), intθ1)
7576
@test y == y_exp'
7677
ygrad = Zygote.gradient(θv -> sum(fy(θv, xPM)), θvec)[1]
7778
if gdev isa MLDataDevices.AbstractGPUDevice
7879
# θg = gdev(θ)
7980
# xPMg = gdev(xPM)
8081
# yg = HVI.DoubleMM.f_doubleMM(θg, xPMg, intθ);
81-
θvecg = gdev(θvec)
82+
θvecg = gdev(θvec); # errors without ";"
8283
xPMg = gdev(xPM)
8384
yg = fy(θvecg, xPMg)
8485
@test cdev(yg) == y_exp'
85-
ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1] # errors without ";"
86+
ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1]
8687
@test ygradg isa CA.ComponentArray
8788
@test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray
8889
ygradgc = HVI.apply_preserve_axes(cdev, ygradg) # can print the cpu version
@@ -94,7 +95,7 @@ end
9495
@testset "neg_logden_obs Matrix" begin
9596
is = repeat(axes(θP_true, 1)', n_site)
9697
θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true)
97-
xPM = map(xP1s -> repeat(xP1s', n_site), xP[1])
98+
xPM = (S1 = CA.getdata(xP[:S1,:])', S2 = CA.getdata(xP[:S2,:])')
9899
#θ = hcat(θP_true[is], θMs_true')
99100
intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1])))
100101
#θpos = get_positions(intθ1)
@@ -111,13 +112,13 @@ end
111112
# θg = gdev(θ)
112113
# xPMg = gdev(xPM)
113114
# yg = HVI.DoubleMM.f_doubleMM(θg, xPMg, intθ);
114-
θvecg = gdev(θvec)
115+
θvecg = gdev(θvec);
115116
xPMg = gdev(xPM)
116-
y_og = gdev(y_o)
117+
y_og = gdev(y_o);
117118
y_uncg = gdev(y_unc)
118119
costg = fcost(θvecg, xPMg, y_og, y_uncg)
119120
@test costg cost
120-
ygradg = Zygote.gradient(θv -> fcost(θv, xPMg, y_og, y_uncg), θvecg)[1] # errors without ";"
121+
ygradg = Zygote.gradient(θv -> fcost(θv, xPMg, y_og, y_uncg), θvecg)[1]; # errors without ";"
121122
@test ygradg isa CA.ComponentArray
122123
@test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray
123124
ygradgc = HVI.apply_preserve_axes(cdev, ygradg) # can print the cpu version

test/test_elbo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,13 @@ test_scenario = (scenario) -> begin
110110
@testset "neg_elbo_gtf cpu" begin
111111
i_sites = 1:n_batch
112112
cost = neg_elbo_gtf(rng, ϕ_ini, g, transPMs_batch, f, py,
113-
xM[:, i_sites], xP[i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites,
113+
xM[:, i_sites], xP[:,i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites,
114114
map(get_concrete, interpreters);
115115
cor_ends, pbm_covar_indices)
116116
@test cost isa Float64
117117
gr = Zygote.gradient(
118118
ϕ -> neg_elbo_gtf(rng, ϕ, g, transPMs_batch, f, py,
119-
xM[:, i_sites], xP[i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites,
119+
xM[:, i_sites], xP[:,i_sites], y_o[:, i_sites], y_unc[:, i_sites], i_sites,
120120
map(get_concrete, interpreters);
121121
cor_ends, pbm_covar_indices),
122122
CA.getdata(ϕ_ini))
@@ -128,7 +128,7 @@ test_scenario = (scenario) -> begin
128128
i_sites = 1:n_batch
129129
ϕ = ggdev(CA.getdata(ϕ_ini))
130130
xMg_batch = ggdev(xM[:, i_sites])
131-
xP_batch = xP[i_sites] # used in f which runs on CPU
131+
xP_batch = xP[:,i_sites] # used in f which runs on CPU
132132
cost = neg_elbo_gtf(rng, ϕ, g_gpu, transPMs_batch, f, py,
133133
xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites], i_sites,
134134
map(get_concrete, interpreters);

0 commit comments

Comments
 (0)