Skip to content

Commit e731305

Browse files
committed
move S1 and S2 in doubleMM problem to drivers
1 parent cdbfa10 commit e731305

File tree

11 files changed

+87
-79
lines changed

11 files changed

+87
-79
lines changed

src/DoubleMM/DoubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ using StatsFuns: logistic
99
using Bijectors
1010

1111

12+
export f_doubleMM, xP_S1, xP_S2
1213
include("f_doubleMM.jl")
1314

14-
export f_doubleMM, S1, S2
1515

1616
end

src/DoubleMM/f_doubleMM.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
struct DoubleMMCase <: AbstractHybridCase end
22

3-
const S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
4-
const S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
53

6-
θP = CA.ComponentVector(r0 = 0.3, K2 = 2.0)
7-
θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
4+
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
5+
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
86

97
transP = elementwise(exp)
108
transM = Stacked(elementwise(identity), elementwise(exp))
119

1210

1311
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
1412

15-
function f_doubleMM(θ::AbstractVector)
13+
function f_doubleMM(θ::AbstractVector, x)
1614
# extract parameters not depending on order, i.e whether they are in θP or θM
1715
θc = int_θdoubleMM(θ)
1816
r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)]
19-
y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
17+
y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
2018
return (y)
2119
end
2220

@@ -40,17 +38,20 @@ function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
4038
end
4139

4240
function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
43-
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
41+
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
4442
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
45-
pred_sites = applyf(fsite, θMs, θP, x)
43+
pred_sites = applyf(f_doubleMM, θMs, θP, x)
4644
pred_global = eltype(pred_sites)[]
4745
return pred_global, pred_sites
4846
end
4947
end
5048

51-
function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario)
52-
return Float32
53-
end
49+
# function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario)
50+
# return Float32
51+
# end
52+
53+
const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
54+
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
5455

5556
function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
5657
scenario = ())
@@ -62,14 +63,14 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
6263
rhodec = 8, is_using_dropout = false)
6364
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
6465
# normalize to be distributed around the prescribed true values
65-
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, 0.1))
66+
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
6667
f = get_hybridcase_PBmodel(case; scenario)
67-
xP = fill((), n_site)
68-
y_global_true, y_true = f(θP, θMs_true, zip())
69-
σ_o = 0.01
68+
xP = fill((;S1=xP_S1, S2=xP_S2), n_site)
69+
y_global_true, y_true = f(θP, θMs_true, xP)
70+
σ_o = FloatType(0.01)
7071
#σ_o = 0.002
71-
y_global_o = y_global_true .+ randn(rng, size(y_global_true)) .* σ_o
72-
y_o = y_true .+ randn(rng, size(y_true)) .* σ_o
72+
y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o
73+
y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o
7374
(;
7475
xM,
7576
n_site,
@@ -83,3 +84,4 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
8384
σ_o = fill(σ_o, size(y_true,1)),
8485
)
8586
end
87+

src/HybridProblem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::N
3737
prob.g, prob.ϕg
3838
end
3939

40-
function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
41-
eltype(prob.θM)
42-
end
40+
# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
41+
# eltype(prob.θM)
42+
# end
4343

4444

4545

src/elbo.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,23 @@ expected value of the likelihood of observations.
1212
including parameter of f (ϕ_P), of g (ϕ_Ms), and of VI (ϕ_unc),
1313
interpreted by interpreters.μP_ϕg_unc and interpreters.PMs
1414
- y_ob: matrix of observations (n_obs x n_site_batch)
15-
- x: matrix of covariates (n_cov x n_site_batch)
15+
- xM: matrix of covariates (n_cov x n_site_batch)
16+
- xP: model drivers, iterable of (n_site_batch)
1617
- transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params
1718
- n_MC: number of MonteCarlo samples from the distribution of parameters to simulate
1819
using the mechanistic model f.
1920
- logσ2y: observation uncertainty (log of the variance)
2021
"""
21-
function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractMatrix,
22-
transPMs, interpreters::NamedTuple;
22+
function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, xM::AbstractMatrix,
23+
xP, transPMs, interpreters::NamedTuple;
2324
n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler(),
2425
entropyN = 0.0,
2526
)
26-
ζs, σ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
27+
ζs, σ = generate_ζ(rng, g, f, ϕ, xM, interpreters; n_MC)
2728
ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
2829
#ζi = first(eachcol(ζs_cpu))
2930
nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
30-
y_pred_i, logjac = predict_y(ζi, f, transPMs, interpreters.PMs)
31+
y_pred_i, logjac = predict_y(ζi, xP, f, transPMs, interpreters.PMs)
3132
nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
3233
nLy1 - logjac
3334
end) / n_MC
@@ -45,7 +46,7 @@ end
4546
4647
Prediction function for hybrid model. Returns an Array `(n_obs, n_site, n_sample_pred)`.
4748
"""
48-
function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters;
49+
function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP, interpreters;
4950
get_transPMs, get_ca_int_PMs, n_sample_pred=200,
5051
gpu_data_handler=get_default_GPUHandler())
5152
n_site = size(xM, 2)
@@ -56,7 +57,7 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret
5657
interpreters_gen; n_MC = n_sample_pred)
5758
ζs_cpu = gpu_data_handler(ζs) #
5859
y_pred = stack(map(ζ -> first(predict_y(
59-
ζ, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
60+
ζ, xP, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
6061
y_pred
6162
end
6263

@@ -68,19 +69,19 @@ Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0`
6869
to the means extracted from parameters and predicted by the machine learning
6970
model.
7071
"""
71-
function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix,
72+
function generate_ζ(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix,
7273
interpreters::NamedTuple; n_MC=3)
7374
# see documentation of neg_elbo_transnorm_gf
7475
ϕc = interpreters.μP_ϕg_unc(CA.getdata(ϕ))
7576
μ_ζP = ϕc.μP
7677
ϕg = ϕc.ϕg
77-
μ_ζMs0 = g(x, ϕg) # TODO provide μ_ζP to g
78+
μ_ζMs0 = g(xM, ϕg) # TODO provide μ_ζP to g
7879
ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC)
7980
#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
8081
ζ = stack(map(eachcol(ζ_resid)) do r
8182
rc = interpreters.PMs(r)
8283
ζP = μ_ζP .+ rc.P
83-
μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g
84+
μ_ζMs = μ_ζMs0 # g(xM, ϕc.ϕ) # TODO provide ζP to g
8485
ζMs = μ_ζMs .+ rc.Ms
8586
vcat(ζP, vec(ζMs))
8687
end)
@@ -168,13 +169,13 @@ Steps:
168169
- transform the parameters to original constrained space
169170
- Applies the mechanistic model for each site
170171
"""
171-
function predict_y(ζi, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter)
172+
function predict_y(ζi, xP, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter)
172173
# θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating
173174
# θc = CA.ComponentVector(θtup)
174175
θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating
175176
θc = int_PMs(θ)
176177
# TODO provide xP
177-
xP = fill((), size(θc.Ms,2))
178+
# xP = fill((), size(θc.Ms,2))
178179
y_pred_global, y_pred = f(θc.P, θc.Ms, xP) # TODO parallelize on CPU
179180
# TODO take care of y_pred_global
180181
y_pred, logjac

src/gf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, x)
2-
# predict several sites with same physical parameters
2+
# predict several sites with same global parameters θP
33
yv = map(eachcol(θMs), x) do θM, x_site
44
f(vcat(θP, θM), x_site)
55
end

src/hybrid_case.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ For a specific case, provide functions that specify details
1010
- get_hybridcase_PBmodel
1111
optionally
1212
- gen_hybridcase_synthetic
13-
- get_hybridcase_FloatType (if it should differ from Float32)
13+
- get_hybridcase_FloatType (defaults to eltype(θM))
1414
"""
1515
abstract type AbstractHybridCase end;
1616

@@ -92,8 +92,8 @@ function gen_hybridcase_synthetic end
9292
9393
Determine the FloatType for given Case and scenario, defaults to Float32
9494
"""
95-
function get_hybridcase_FloatType(::AbstractHybridCase; scenario)
96-
return Float32
95+
function get_hybridcase_FloatType(case::AbstractHybridCase; scenario)
96+
return eltype(get_hybridcase_par_templates(case; scenario).θM)
9797
end
9898

9999

src/init_hybrid_params.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ function init_hybrid_params(θP, θM, ϕg, n_batch;
2727
# check translating parameters - can match length?
2828
_ = Bijectors.inverse(transP)(θP)
2929
_ = Bijectors.inverse(transM)(θM)
30+
FT = eltype(θM)
3031
# zero correlation matrices
31-
ρsP = zeros(sum(1:(n_θP - 1)))
32-
ρsM = zeros(sum(1:(n_θM - 1)))
32+
ρsP = zeros(FT, sum(1:(n_θP - 1)))
33+
ρsM = zeros(FT, sum(1:(n_θM - 1)))
3334
ϕunc0 = CA.ComponentVector(;
34-
logσ2_logP = fill(-10.0, n_θP),
35-
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
35+
logσ2_logP = fill(FT(-10.0), n_θP),
36+
coef_logσ2_logMs = reduce(hcat, (FT[-10.0, 0.0] for _ in 1:n_θM)),
3637
ρsP,
3738
ρsM)
3839
ϕ = CA.ComponentVector(;

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ end
3333
if GROUP == "All" || GROUP == "Aqua"
3434
#@safetestset "test" include("test/test_aqua.jl")
3535
if VERSION >= VersionNumber("1.11.2")
36+
#@safetestset "test" include("test/test_aqua.jl")
3637
@time @safetestset "test_aqua" include("test_aqua.jl")
3738
end
3839
end

test/test_HybridProblem.jl

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,42 +14,38 @@ using OptimizationOptimisers
1414

1515
const MLengine = Val(nameof(SimpleChains))
1616

17-
1817
construct_problem = () -> begin
19-
S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
20-
S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
2118
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
22-
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
19+
θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
2320
transP = elementwise(exp)
2421
transM = Stacked(elementwise(identity), elementwise(exp))
2522
n_covar = 5
2623
n_batch = 10
2724
int_θdoubleMM = get_concrete(ComponentArrayInterpreter(
2825
flatten1(CA.ComponentVector(; θP, θM))))
29-
function f_doubleMM(θ::AbstractVector)
26+
function f_doubleMM(θ::AbstractVector, x)
3027
# extract parameters not depending on order, i.e whether they are in θP or θM
3128
θc = int_θdoubleMM(θ)
3229
r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)]
33-
y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
30+
y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
3431
return (y)
3532
end
36-
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
3733
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
38-
pred_sites = applyf(fsite, θMs, θP, x)
34+
pred_sites = applyf(f_doubleMM, θMs, θP, x)
3935
pred_global = eltype(pred_sites)[]
4036
return pred_global, pred_sites
41-
end
37+
end
4238
n_out = length(θM)
4339
g_chain = SimpleChain(
44-
static(n_covar), # input dimension (optional)
45-
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
46-
TurboDense{true}(tanh, n_covar * 4),
47-
TurboDense{true}(tanh, n_covar * 4),
48-
# dense layer without bias that maps to n outputs and `identity` activation
49-
TurboDense{false}(identity, n_out),
50-
)
40+
static(n_covar), # input dimension (optional)
41+
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
42+
TurboDense{true}(tanh, n_covar * 4),
43+
TurboDense{true}(tanh, n_covar * 4),
44+
# dense layer without bias that maps to n outputs and `identity` activation
45+
TurboDense{false}(identity, n_out)
46+
)
5147
g = construct_SimpleChainsApplicator(g_chain)
52-
ϕg = SimpleChains.init_params(g_chain, eltype(θM));
48+
ϕg = SimpleChains.init_params(g_chain, eltype(θM))
5349
HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, g, ϕg)
5450
end
5551
prob = construct_problem();
@@ -65,7 +61,7 @@ rng = StableRNG(111)
6561
) = gen_hybridcase_synthetic(case_syn, rng; scenario);
6662

6763
@testset "loss_g" begin
68-
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario);
64+
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
6965

7066
function loss_g(ϕg, x, g)
7167
ζMs = g(x, ϕg) # predict the log of the parameters
@@ -74,15 +70,15 @@ rng = StableRNG(111)
7470
return loss, θMs
7571
end
7672
loss_g(ϕg0, xM, g)
77-
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
73+
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0)
7874

7975
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
8076
Optimization.AutoZygote())
81-
optprob = Optimization.OptimizationProblem(optf, ϕg0);
77+
optprob = Optimization.OptimizationProblem(optf, ϕg0)
8278
#res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
83-
res = Optimization.solve(optprob, Adam(0.02), maxiters = 600);
79+
res = Optimization.solve(optprob, Adam(0.02), maxiters = 600)
8480

85-
ϕg_opt1 = res.u;
81+
ϕg_opt1 = res.u
8682
pred = loss_g(ϕg_opt1, xM, g)
8783
θMs_pred = pred[2]
8884
#scatterplot(vec(θMs_true), vec(θMs_pred))
@@ -91,12 +87,12 @@ end
9187

9288
@testset "loss_gf" begin
9389
#----------- fit g and θP to y_o
94-
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario);
90+
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
9591
f = get_hybridcase_PBmodel(prob; scenario)
9692

9793
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
9894
ϕg = 1:length(ϕg0), θP = par_templates.θP))
99-
p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true
95+
p = p0 = vcat(ϕg0, par_templates.θP .* 0.8) # slightly disturb θP_true
10096

10197
# Pass the site-data for the batches as separate vectors wrapped in a tuple
10298
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
@@ -109,8 +105,8 @@ end
109105
optprob = OptimizationProblem(optf, p0, train_loader)
110106

111107
res = Optimization.solve(
112-
# optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
113-
optprob, Adam(0.02), maxiters = 1000);
108+
# optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
109+
optprob, Adam(0.02), maxiters = 1000)
114110

115111
l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
116112
@test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)

test/test_doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ end
6161
end
6262

6363
@testset "loss_gf" begin
64-
#----------- fit g and θP to y_o
64+
#----------- fit g and θP to y_o (without transformations)
6565
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
6666
f = get_hybridcase_PBmodel(case; scenario)
6767

0 commit comments

Comments
 (0)