Skip to content

Commit df48ad4

Browse files
committed
reorder arguments of elbo
random generator parameters functions drivers observation and uncertainties
1 parent f4acd16 commit df48ad4

File tree

4 files changed

+33
-29
lines changed

4 files changed

+33
-29
lines changed

dev/doubleMM.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
188188
end
189189

190190
function fcost(ϕ, xM, y_o, y_unc)
191-
neg_elbo_transnorm_gf(rng, g_flux, f, py, CA.getdata(ϕ), y_o, y_unc,
192-
xM, xP, transPMs_batch, map(get_concrete, interpreters);
191+
neg_elbo_transnorm_gf(rng, CA.getdata(ϕ), g_flux, transPMs_batch, f, py,
192+
xM, xP, y_o, y_unc, map(get_concrete, interpreters);
193193
n_MC = 8)
194194
end
195195
fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch])

src/elbo.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ expected value of the likelihood of observations.
66
77
## Arguments
88
- `rng`: random number generator (ignored on CUDA, if ϕ is a AbstractGPUArray)
9-
- `ϕ`: flat vector of parameters
10-
including parameter of f (ϕ_P), of g (ϕ_Ms), and of VI (ϕ_unc),
9+
- `ϕ`: flat vector of parameters, interpreted by interpreters
1110
interpreted by interpreters.μP_ϕg_unc and interpreters.PMs
1211
- `g`: machine learning model
1312
- `transPMs`: Transformations as generated by get_transPMs returned from init_hybrid_params
@@ -18,12 +17,17 @@ expected value of the likelihood of observations.
1817
- `xP`: model drivers, iterable of (n_site_batch)
1918
- `y_ob`: matrix of observations (n_obs x n_site_batch)
2019
- `y_unc`: observation uncertainty provided to py (same size as y_ob)
21-
- interpreters:
20+
- `interpreters`: NamedTuple as generated by `gen_hybridcase_synthetic` with entries:
21+
- `μP_ϕg_unc`: extract components of parameter of
22+
1) means of global PBM, 2) ML-weights, and 3) additional parameters of approximation q
23+
- `PMs`: assign components to PBM parameters 1 global, 2 matrix of n_site column vectors
24+
- `int_unc` (can be omitted, if `μP_ϕg_unc(ϕ).unc` is already a ComponentVector)
2225
- `n_MC`: number of MonteCarlo samples from the distribution of parameters to simulate
2326
using the mechanistic model f.
2427
"""
25-
function neg_elbo_transnorm_gf(rng, g, f, py, ϕ::AbstractVector, y_ob, y_unc,
26-
xM::AbstractMatrix, xP, transPMs, interpreters::NamedTuple;
28+
function neg_elbo_transnorm_gf(rng, ϕ::AbstractVector, g, transPMs, f, py,
29+
xM::AbstractMatrix, xP, y_ob, y_unc,
30+
interpreters::NamedTuple;
2731
n_MC=3, gpu_data_handler = get_default_GPUHandler(),
2832
cor_starts=(P=(1,),M=(1,))
2933
)

test/test_HybridProblem.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ import Flux
116116

117117
py = get_hybridcase_neg_logden_obs(prob)
118118

119-
cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ_ini, y_o, y_unc,
120-
xM, xP, transPMs_batch, map(get_concrete, interpreters);
119+
cost = neg_elbo_transnorm_gf(rng, ϕ_ini, g, transPMs_batch, f, py,
120+
xM, xP, y_o, y_unc, map(get_concrete, interpreters);
121121
n_MC=8)
122122
@test cost isa Float64
123123
gr = Zygote.gradient(
124-
ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc,
125-
xM, xP, transPMs_batch, map(get_concrete, interpreters);
126-
n_MC=8),
124+
ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py,
125+
xM, xP, y_o, y_unc, map(get_concrete, interpreters);
126+
n_MC=8),
127127
CA.getdata(ϕ_ini))
128128
@test gr[1] isa Vector
129129

@@ -144,14 +144,14 @@ import Flux
144144
ϕ_ini.ϕg = ϕg0
145145
ϕ = CuArray(CA.getdata(ϕ_ini))
146146
xMg = CuArray(xM)
147-
cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc,
148-
xMg, xP, transPMs_batch, map(get_concrete, interpreters);
147+
cost = neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py,
148+
xMg, xP, y_o, y_unc, map(get_concrete, interpreters);
149149
n_MC=8)
150150
@test cost isa Float64
151151
gr = Zygote.gradient(
152-
ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc,
153-
xMg, xP, transPMs_batch, map(get_concrete, interpreters);
154-
n_MC=8),
152+
ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py,
153+
xMg, xP, y_o, y_unc, map(get_concrete, interpreters);
154+
n_MC=8),
155155
ϕ)
156156
@test gr[1] isa CuVector
157157
@test eltype(gr[1]) == get_hybridcase_float_type(prob)

test/test_elbo.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,15 @@ if CUDA.functional()
8080
end
8181

8282
@testset "neg_elbo_transnorm_gf cpu" begin
83-
cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ_ini, y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
84-
xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters);
83+
cost = neg_elbo_transnorm_gf(rng, ϕ_ini, g, transPMs_batch, f, py,
84+
xM[:, 1:n_batch], xP[1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
85+
map(get_concrete, interpreters);
8586
n_MC = 8)
8687
@test cost isa Float64
8788
gr = Zygote.gradient(
88-
ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
89-
xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters);
89+
ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py,
90+
xM[:, 1:n_batch], xP[1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
91+
map(get_concrete, interpreters);
9092
n_MC = 8),
9193
CA.getdata(ϕ_ini))
9294
@test gr[1] isa Vector
@@ -97,17 +99,15 @@ if CUDA.functional()
9799
ϕ = CuArray(CA.getdata(ϕ_ini))
98100
xMg_batch = CuArray(xM[:, 1:n_batch])
99101
xP_batch = xP[1:n_batch] # used in f which runs on CPU
100-
cost = neg_elbo_transnorm_gf(rng, g_flux, f, py, ϕ,
101-
y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
102-
xMg_batch, xP_batch,
103-
transPMs_batch, map(get_concrete, interpreters);
102+
cost = neg_elbo_transnorm_gf(rng, ϕ, g_flux, transPMs_batch, f, py,
103+
xMg_batch, xP_batch, y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
104+
map(get_concrete, interpreters);
104105
n_MC = 8)
105106
@test cost isa Float64
106107
gr = Zygote.gradient(
107-
ϕ -> neg_elbo_transnorm_gf(rng, g_flux, f, py, ϕ,
108-
y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
109-
xMg_batch, xP_batch,
110-
transPMs_batch, map(get_concrete, interpreters);
108+
ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g_flux, transPMs_batch, f, py,
109+
xMg_batch, xP_batch, y_o[:, 1:n_batch], y_unc[:, 1:n_batch],
110+
map(get_concrete, interpreters);
111111
n_MC = 8),
112112
ϕ)
113113
@test gr[1] isa CuVector

0 commit comments

Comments
 (0)