Skip to content

Commit 0f03d3a

Browse files
committed
implement ML model depending on thetaP
1 parent 7192975 commit 0f03d3a

17 files changed

+516
-377
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,30 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2020
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2324
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2425
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2526

2627
[weakdeps]
28+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2729
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2830
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
2931
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
30-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3132

3233
[extensions]
34+
HybridVariationalInferenceCUDAExt = "CUDA"
3335
HybridVariationalInferenceFluxExt = "Flux"
3436
HybridVariationalInferenceLuxExt = "Lux"
3537
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
36-
HybridVariationalInferenceCUDAExt = "CUDA"
3738

3839
[compat]
3940
Bijectors = "0.14, 0.15"
4041
BlockDiagonals = "0.1.42, 0.2"
42+
CUDA = "5.7"
4143
ChainRulesCore = "1.25"
4244
Combinatorics = "1.0.2"
4345
CommonSolve = "0.2.4"
4446
ComponentArrays = "0.15.19"
45-
CUDA = "5.7"
4647
DistributionFits = "0.3.9"
4748
Distributions = "0.25.117"
4849
Flux = "0.14, 0.15, 0.16"
@@ -56,6 +57,7 @@ Optimization = "3.19.3, 4"
5657
Random = "1.10.0"
5758
SimpleChains = "0.4"
5859
StableRNGs = "1.0.2"
60+
StaticArrays = "1.9.13"
5961
StatsBase = "0.34.4"
6062
StatsFuns = "1.3.2"
6163
julia = "1.10"

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ end
129129
() -> begin # optimized loss is indeed lower than with true parameters
130130
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
131131
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
132-
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.f, Float32[], int_ϕθP)
132+
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.transP, prob0.f, Float32[], int_ϕθP)
133133
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc, i_sites)[1]
134134
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc, i_sites)[1]
135135
#

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,18 @@ function HVI.construct_3layer_MLApplicator(
4242
(;θM) = get_hybridproblem_par_templates(prob; scenario)
4343
n_out = length(θM)
4444
n_covar = get_hybridproblem_n_covar(prob; scenario)
45+
n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario))
46+
n_input = n_covar + n_pbm_covars
4547
#(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario)
4648
float_type = get_hybridproblem_float_type(prob; scenario)
4749
is_using_dropout = :use_dropout ∈ scenario
4850
is_using_dropout && error("dropout scenario not supported with Flux yet.")
4951
g_chain = Flux.Chain(
5052
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
51-
Flux.Dense(n_covar => n_covar * 4, tanh),
52-
Flux.Dense(n_covar * 4 => n_covar * 4, tanh),
53+
Flux.Dense(n_input => n_input * 4, tanh),
54+
Flux.Dense(n_input * 4 => n_input * 4, tanh),
5355
# dense layer without bias that maps to n outputs and `logistic` activation
54-
Flux.Dense(n_covar * 4 => n_out, logistic, bias = false)
56+
Flux.Dense(n_input * 4 => n_out, logistic, bias = false)
5557
)
5658
construct_ChainsApplicator(rng, g_chain, float_type)
5759
end

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,29 @@ function HVI.construct_3layer_MLApplicator(
2121
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains};
2222
scenario::NTuple = ())
2323
n_covar = get_hybridproblem_n_covar(prob; scenario)
24+
n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario))
25+
n_input = n_covar + n_pbm_covars
2426
FloatType = get_hybridproblem_float_type(prob; scenario)
2527
(;θM) = get_hybridproblem_par_templates(prob; scenario)
2628
n_out = length(θM)
2729
is_using_dropout = :use_dropout ∈ scenario
2830
g_chain = if is_using_dropout
2931
SimpleChain(
30-
static(n_covar), # input dimension (optional)
32+
static(n_input), # input dimension (optional)
3133
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
32-
TurboDense{true}(tanh, n_covar * 4),
34+
TurboDense{true}(tanh, n_input * 4),
3335
SimpleChains.Dropout(0.2), # dropout layer
34-
TurboDense{true}(tanh, n_covar * 4),
36+
TurboDense{true}(tanh, n_input * 4),
3537
SimpleChains.Dropout(0.2),
3638
# dense layer without bias that maps to n outputs and `logistic` activation
3739
TurboDense{false}(logistic, n_out)
3840
)
3941
else
4042
SimpleChain(
41-
static(n_covar), # input dimension (optional)
43+
static(n_input), # input dimension (optional)
4244
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
43-
TurboDense{true}(tanh, n_covar * 4),
44-
TurboDense{true}(tanh, n_covar * 4),
45+
TurboDense{true}(tanh, n_input * 4),
46+
TurboDense{true}(tanh, n_input * 4),
4547
# dense layer without bias that maps to n outputs and `logistic` activation
4648
TurboDense{false}(logistic, n_out)
4749
)

src/AbstractHybridProblem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ optionally
1818
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
1919
- `get_hybridproblem_cor_ends` (defaults to include all correlations:
2020
`(P = [length(θP)], M = [length(θM)])`)
21+
- `get_hybridproblem_pbmpar_covars` (defaults to empty tuple)
22+
2123
2224
The initial value of parameters to estimate is spread
2325
- `ϕg`: parameter of the MLapplicator: returned by `get_hybridproblem_MLapplicator`
@@ -117,6 +119,11 @@ function get_hybridproblem_n_covar(::AbstractHybridProblem; scenario) end
117119
# return (n_covar)
118120
# end
119121

122+
123+
function get_hybridproblem_pbmpar_covars(::AbstractHybridProblem; scenario)
124+
()
125+
end
126+
120127
"""
121128
get_hybridproblem_n_site(::AbstractHybridProblem; scenario)
122129

src/DoubleMM/DoubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Bijectors
1212
using Distributions, DistributionFits
1313
using MLDataDevices
1414
import GPUArraysCore # used in conditional breakpoints
15-
15+
import StableRNGs
1616

1717
export f_doubleMM, xP_S1, xP_S2
1818
include("f_doubleMM.jl")

src/DoubleMM/f_doubleMM.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ function HVI.get_hybridproblem_priors(::DoubleMMCase; scenario = ())
4141
Dict(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
4242
end
4343

44+
function HVI.get_hybridproblem_MLapplicator(prob::HVI.DoubleMM.DoubleMMCase; scenario = ())
45+
rng = StableRNGs.StableRNG(111)
46+
get_hybridproblem_MLapplicator(rng, prob; scenario)
47+
end
48+
4449
function HVI.get_hybridproblem_MLapplicator(
4550
rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ())
4651
ml_engine = select_ml_engine(; scenario)
@@ -53,6 +58,14 @@ function HVI.get_hybridproblem_MLapplicator(
5358
return g, ϕ_g0
5459
end
5560

61+
function HVI.get_hybridproblem_pbmpar_covars(::DoubleMMCase; scenario)
62+
if (:covarK2 ∈ scenario)
63+
return (:K2,)
64+
end
65+
()
66+
end
67+
68+
5669
function HVI.get_hybridproblem_transforms(::DoubleMMCase; scenario::NTuple = ())
5770
if (:stackedMS ∈ scenario)
5871
return ((; transP, transM = transMS))

src/HybridProblem.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct HybridProblem <: AbstractHybridProblem
1313
get_train_loader::Any
1414
n_covar::Int
1515
n_site::Int
16+
pbm_covars::NTuple
1617
# inner constructor to constrain the types
1718
function HybridProblem(
1819
θP::CA.ComponentVector, θM::CA.ComponentVector,
@@ -28,10 +29,11 @@ struct HybridProblem <: AbstractHybridProblem
2829
n_covar::Int,
2930
n_site::Int,
3031
cor_ends::NamedTuple = (P = [length(θP)], M = [length(θM)]),
31-
)
32+
pbm_covars::NTuple{N,Symbol} = (),
33+
) where N
3234
new(
3335
θP, θM, f, g, ϕg, ϕunc, priors, py, transM, transP, cor_ends, get_train_loader,
34-
n_covar, n_site)
36+
n_covar, n_site, pbm_covars)
3537
end
3638
end
3739

@@ -57,11 +59,12 @@ function HybridProblem(prob::AbstractHybridProblem; scenario = ())
5759
end
5860
end
5961
cor_ends = get_hybridproblem_cor_ends(prob; scenario)
62+
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
6063
priors = get_hybridproblem_priors(prob; scenario)
6164
n_covar = get_hybridproblem_n_covar(prob; scenario)
6265
n_site = get_hybridproblem_n_site(prob; scenario)
6366
HybridProblem(θP, θM, g, ϕg, ϕunc, f, priors, py, transP, transM, get_train_loader,
64-
n_covar, n_site, cor_ends)
67+
n_covar, n_site, cor_ends, pbm_covars)
6568
end
6669

6770
function update(prob::HybridProblem;
@@ -76,12 +79,13 @@ function update(prob::HybridProblem;
7679
transM::Union{Function, Bijectors.Transform} = prob.transM,
7780
transP::Union{Function, Bijectors.Transform} = prob.transP,
7881
cor_ends::NamedTuple = prob.cor_ends,
82+
pbm_covars::NTuple{N,Symbol} = prob.pbm_covars,
7983
get_train_loader::Function = prob.get_train_loader,
8084
n_covar::Integer = prob.n_covar,
8185
n_site::Integer = prob.n_site
82-
)
86+
) where N
8387
HybridProblem(θP, θM, g, ϕg, ϕunc, f, priors, py, transP, transM, get_train_loader,
84-
n_covar, n_site, cor_ends)
88+
n_covar, n_site, cor_ends, pbm_covars)
8589
end
8690

8791
function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ())
@@ -121,6 +125,9 @@ end
121125
function get_hybridproblem_cor_ends(prob::HybridProblem; scenario = ())
122126
prob.cor_ends
123127
end
128+
function get_hybridproblem_pbmpar_covars(prob::HybridProblem; scenario = ())
129+
prob.pbm_covars
130+
end
124131
function get_hybridproblem_n_covar(prob::HybridProblem; scenario = ())
125132
prob.n_covar
126133
end

src/HybridSolver.jl

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,38 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
1717
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario)
1818
FT = get_hybridproblem_float_type(prob; scenario)
1919
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
20-
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
21-
ϕg = 1:length(ϕg0), θP = par_templates.θP))
22-
#p0_cpu = vcat(ϕg0, par_templates.θP .* FT(0.9)) # slightly disturb θP_true
23-
p0_cpu = vcat(ϕg0, par_templates.θP)
24-
p0 = p0_cpu
25-
g_dev = g
20+
intϕ = ComponentArrayInterpreter(CA.ComponentVector(
21+
ϕg = 1:length(ϕg0), ϕP = par_templates.θP))
22+
#ϕ0_cpu = vcat(ϕg0, par_templates.θP .* FT(0.9)) # slightly disturb θP_true
23+
ϕ0_cpu = vcat(ϕg0, apply_preserve_axes(inverse(transP),par_templates.θP))
2624
if gdev isa MLDataDevices.AbstractGPUDevice
27-
p0 = gdev(p0_cpu)
25+
ϕ0_dev = gdev(ϕ0_cpu)
2826
g_dev = gdev(g)
27+
else
28+
ϕ0_dev = ϕ0_cpu
29+
g_dev = g
2930
end
3031
train_loader = get_hybridproblem_train_dataloader(
3132
prob; scenario, n_batch = solver.n_batch)
3233
f = get_hybridproblem_PBmodel(prob; scenario)
3334
y_global_o = FT[] # TODO
34-
loss_gf = get_loss_gf(g_dev, transM, f, y_global_o, int_ϕθP; cdev)
35+
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
36+
#intP = ComponentArrayInterpreter(par_templates.θP)
37+
loss_gf = get_loss_gf(g_dev, transM, transP, f, y_global_o, intϕ; cdev, pbm_covars)
3538
# call loss function once
36-
l1 = loss_gf(p0, first(train_loader)...)[1]
39+
l1 = loss_gf(ϕ0_dev, first(train_loader)...)[1]
3740
# and gradient
3841
# xMg, xP, y_o, y_unc = first(train_loader)
3942
# gr1 = Zygote.gradient(
4043
# p -> loss_gf(p, xMg, xP, y_o, y_unc)[1],
41-
# p0)
44+
# ϕ0_dev)
4245
# data1 = first(train_loader)
43-
# Zygote.gradient(p0 -> loss_gf(p0, data1...)[1], p0)
46+
# Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev, data1...)[1], ϕ0_dev)
4447
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
4548
Optimization.AutoZygote())
46-
optprob = OptimizationProblem(optf, CA.getdata(p0), train_loader)
49+
optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader)
4750
res = Optimization.solve(optprob, solver.alg; kwargs...)
48-
(; ϕ = int_ϕθP(res.u), resopt = res)
51+
(; ϕ = intϕ(res.u), resopt = res)
4952
end
5053

5154
struct HybridPosteriorSolver{A} <: AbstractHybridSolver
@@ -77,6 +80,7 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
7780
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario)
7881
ϕunc0 = get_hybridproblem_ϕunc(prob; scenario)
7982
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
83+
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
8084
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
8185
θP, θM, cor_ends, ϕg0, solver.n_batch; transP, transM, ϕunc0)
8286
if gdev isa MLDataDevices.AbstractGPUDevice
@@ -90,12 +94,12 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
9094
f = get_hybridproblem_PBmodel(prob; scenario)
9195
py = get_hybridproblem_neg_logden_obs(prob; scenario)
9296
priors_θ_mean = construct_priors_θ_mean(
93-
prob, ϕ0_dev.ϕg, keys(θM), θP, θmean_quant, g_dev, transM;
94-
scenario, get_ca_int_PMs, cdev)
97+
prob, ϕ0_dev.ϕg, keys(θM), θP, θmean_quant, g_dev, transM, transP;
98+
scenario, get_ca_int_PMs, cdev, pbm_covars)
9599
y_global_o = Float32[] # TODO
96100
loss_elbo = get_loss_elbo(
97101
g_dev, transPMs_batch, f, py, y_global_o, interpreters;
98-
solver.n_MC, solver.n_MC_cap, cor_ends, priors_θ_mean, cdev)
102+
solver.n_MC, solver.n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covars, θP)
99103
# test loss function once
100104
l0 = loss_elbo(ϕ0_dev, rng, first(train_loader)...)
101105
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_elbo(ϕ, rng, data...)[1],
@@ -116,28 +120,32 @@ end
116120

117121
"""
118122
Create a loss function for parameter vector ϕ, given
119-
- g(x, ϕ): machine learning model
120-
- transPMS: transformation from unconstrained space to parameter space
121-
- f(θMs, θP): mechanistic model
122-
- interpreters: assigning structure to pure vectors, see neg_elbo_gtf
123-
- n_MC: number of Monte-Carlo sample to approximate the expected value across distribution
123+
- `g(x, ϕ)`: machine learning model
124+
- `transPMS`: transformation from unconstrained space to parameter space
125+
- `f(θMs, θP)`: mechanistic model
126+
- `interpreters`: assigning structure to pure vectors, see `neg_elbo_gtf`
127+
- `n_MC`: number of Monte-Carlo samples to approximate the expected value across distribution
128+
- `pbm_covars`: tuple of symbols of process-based parameters provided to the ML model
129+
- `θP`: ComponentVector as a template to select indices of `pbm_covars`
124130
125131
The loss function takes in addition to ϕ, data that changes with minibatch
126-
- rng: random generator
127-
- xM: matrix of covariates, sites in columns
128-
- xP: drivers for the processmodel: Iterator of size n_site
129-
- y_o, y_unc: matrix of observations and uncertainties, sites in columns
132+
- `rng`: random generator
133+
- `xM`: matrix of covariates, sites in columns
134+
- `xP`: drivers for the processmodel: Iterator of size n_site
135+
- `y_o`, `y_unc`: matrix of observations and uncertainties, sites in columns
130136
"""
131137
function get_loss_elbo(g, transPMs, f, py, y_o_global, interpreters;
132-
n_MC, n_MC_cap = n_MC, cor_ends, priors_θ_mean, cdev)
138+
n_MC, n_MC_cap = n_MC, cor_ends, priors_θ_mean, cdev, pbm_covars, θP,
139+
)
133140
let g = g, transPMs = transPMs, f = f, py = py, y_o_global = y_o_global, n_MC = n_MC,
134141
cor_ends = cor_ends, interpreters = map(get_concrete, interpreters),
135-
priors_θ_mean = priors_θ_mean, cdev = cdev
142+
priors_θ_mean = priors_θ_mean, cdev = cdev,
143+
pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars)
136144

137145
function loss_elbo(ϕ, rng, xM, xP, y_o, y_unc, i_sites)
138146
neg_elbo_gtf(
139147
rng, ϕ, g, transPMs, f, py, xM, xP, y_o, y_unc, i_sites, interpreters;
140-
n_MC, n_MC_cap, cor_ends, priors_θ_mean, cdev)
148+
n_MC, n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covar_indices)
141149
end
142150
end
143151
end
@@ -183,16 +191,19 @@ end
183191
In order to let mean of θ stay close to initial point parameter estimates
184192
construct a prior on mean θ to a Normal around initial prediction.
185193
"""
186-
function construct_priors_θ_mean(prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM;
187-
scenario, get_ca_int_PMs, cdev)
194+
function construct_priors_θ_mean(prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM, transP;
195+
scenario, get_ca_int_PMs, cdev, pbm_covars)
188196
iszero(θmean_quant) ? [] :
189197
begin
190198
n_site = get_hybridproblem_n_site(prob; scenario)
191199
all_loader = get_hybridproblem_train_dataloader(prob; scenario, n_batch = n_site)
192200
xM_all = first(all_loader)[1]
193-
θMs = gtrans(g_dev, transM, xM_all, CA.getdata(ϕg); cdev)
194-
priors_dict = get_hybridproblem_priors(prob; scenario)
195201
#Main.@infiltrate_main
202+
ζP = apply_preserve_axes(inverse(transP), θP)
203+
pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars)
204+
xMP_all = _append_each_covars(xM_all, CA.getdata(ζP), pbm_covar_indices)
205+
θMs = gtrans(g_dev, transM, xMP_all, CA.getdata(ϕg); cdev)
206+
priors_dict = get_hybridproblem_priors(prob; scenario)
196207
priorsP = [priors_dict[k] for k in keys(θP)]
197208
priors_θP_mean = map(priorsP, θP) do priorsP, θPi
198209
fit_narrow_normal(θPi, priorsP, θmean_quant)

src/HybridVariationalInference.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using CommonSolve
1717
#using OptimizationOptimisers # default alg=Adam(0.02)
1818
using Optimization
1919
using Distributions, DistributionFits
20+
using StaticArrays: StaticArrays as SA
2021
using Functors
2122

2223
export ComponentArrayInterpreter, flatten1, get_concrete
@@ -40,6 +41,7 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_
4041
get_hybridproblem_n_site,
4142
get_hybridproblem_cor_ends,
4243
get_hybridproblem_priors,
44+
get_hybridproblem_pbmpar_covars,
4345
#update,
4446
gen_cov_pred,
4547
construct_dataloader_from_synthetic,

0 commit comments

Comments
 (0)