Skip to content

Commit 0b15204

Browse files
committed
fix error in confusing parameter positions
depending on scenario-templates need to construct different Interpreters passed to forward model
1 parent a0bb4f0 commit 0b15204

File tree

6 files changed

+28
-12
lines changed

6 files changed

+28
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ dev/Manifest*.toml
99
tmp/
1010
**/tmp.svg
1111
dev/intermediate/*
12+
dev/tmp.pdf

dev/intermediate/probos.jld2

-38.2 KB
Binary file not shown.

src/AbstractHybridProblem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,11 @@ function get_hybridproblem_cor_ends(prob::AbstractHybridProblem; scenario = ())
227227
pt = get_hybridproblem_par_templates(prob; scenario)
228228
(P = [length(pt.θP)], M = [length(pt.θM)])
229229
end
230+
231+
232+
function setup_PBMpar_interpreter(θP, θM, θall = vcat(θP, θM))
233+
keys_fixed = ((k for k in keys(θall) if (k keys(θP)) & (k keys(θM)))...,)
234+
θFix = θall[keys_fixed]
235+
intθ = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
236+
intθ, θFix
237+
end

src/DoubleMM/f_doubleMM.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ const transMS = Stacked(elementwise(identity), elementwise(exp))
1313

1414
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
1515

16-
function f_doubleMM::AbstractVector, x)
16+
function f_doubleMM::AbstractVector, x, intθ)
1717
# extract parameters not depending on order, i.e whether they are in θP or θM
1818
y = GPUArraysCore.allowscalar() do
19-
θc = int_θdoubleMM(θ)
19+
θc = intθ(θ)
2020
#using ComponentArrays: ComponentArrays as CA
2121
#r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] # does not work on Zygote+GPU
2222
r0 = θc[:r0]
@@ -30,8 +30,10 @@ end
3030

3131
function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ())
3232
if (:omit_r0 scenario)
33+
#return ((; θP = θP_nor0, θM, θf = θP[(:K2r)]))
3334
return ((; θP = θP_nor0, θM))
3435
end
36+
#(; θP, θM, θf = eltype(θP)[])
3537
(; θP, θM)
3638
end
3739

@@ -74,11 +76,10 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = ()
7476
)
7577
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
7678
par_templates = get_hybridproblem_par_templates(prob; scenario)
77-
keys_fixed = ((k for k in keys(θall) if
78-
(k keys(par_templates.θP)) & (k keys(par_templates.θM)))...,)
79-
let θFix = gdev(θall[keys_fixed])
79+
intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
80+
let θFix = gdev(θFix), intθ = get_concrete(intθ)
8081
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
81-
pred_sites = applyf(f_doubleMM, θMs, θP, θFix, x)
82+
pred_sites = applyf(f_doubleMM, θMs, θP, θFix, x, intθ)
8283
pred_global = eltype(pred_sites)[]
8384
return pred_global, pred_sites
8485
end
@@ -101,7 +102,12 @@ const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0]
101102
# const xP_S2 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
102103

103104
HVI.get_hybridproblem_n_covar(prob::DoubleMMCase; scenario) = 5
104-
HVI.get_hybridproblem_n_site(prob::DoubleMMCase; scenario) = 800
105+
function HVI.get_hybridproblem_n_site(prob::DoubleMMCase; scenario)
106+
if (:few_sites scenario)
107+
return(100)
108+
end
109+
800
110+
end
105111

106112
function HVI.get_hybridproblem_train_dataloader(prob::DoubleMMCase; scenario = (),
107113
n_batch, rng::AbstractRNG = StableRNG(111), kwargs...

src/HybridVariationalInference.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_
4242
get_hybridproblem_priors,
4343
#update,
4444
gen_cov_pred,
45-
construct_dataloader_from_synthetic
45+
construct_dataloader_from_synthetic,
46+
setup_PBMpar_interpreter
4647
include("AbstractHybridProblem.jl")
4748

4849
export HybridProblem

src/gf.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP)
1+
function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, θFix::AbstractVector, xP, args...; kwargs...)
22
# predict several sites with same global parameters θP and fixed parameters θFix
33
yv = map(eachcol(θMs), xP) do θM, x_site
4-
f(vcat(θP, θM, θFix), x_site)
4+
f(vcat(θP, θM, θFix), x_site, args...; kwargs...)
55
end
66
y = stack(yv)
77
return(y)
88
end
9-
function applyf(f, θMs::AbstractMatrix, θPs::AbstractMatrix, θFix::AbstractVector, xP)
9+
function applyf(f, θMs::AbstractMatrix, θPs::AbstractMatrix, θFix::AbstractVector, xP, args...; kwargs...)
1010
# do not call f with matrix θ, because .* with vectors S1 would go wrong
1111
yv = map(eachcol(θMs), eachcol(θPs), xP) do θM, θP, xP_site
12-
f(vcat(θP, θM, θFix), xP_site)
12+
f(vcat(θP, θM, θFix), xP_site, args...; kwargs...)
1313
end
1414
y = stack(yv)
1515
return(y)

0 commit comments

Comments
 (0)