Skip to content

Commit 357f1ec

Browse files
authored
Merge pull request #14 from EarthyScience/dev
remove get_hybrid_case_sizes and MLEngine from AbstractHybridCase
2 parents 241e54d + 463531b commit 357f1ec

23 files changed

+691
-479
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ jobs:
4949
with:
5050
files: lcov.info
5151
token: ${{ secrets.CODECOV_TOKEN }}
52-
fail_ci_if_error: false
52+
#fail_ci_if_error: false
53+
fail_ci_if_error: true
5354
docs:
5455
name: Documentation
5556
runs-on: ubuntu-latest

dev/doubleMM.jl

Lines changed: 48 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Statistics
77
using ComponentArrays: ComponentArrays as CA
88

99
using SimpleChains
10-
import Flux # to allow for FluxMLEngine and cpu()
10+
import Flux
1111
using MLUtils
1212
import Zygote
1313

@@ -17,41 +17,43 @@ using Bijectors
1717
using UnicodePlots
1818

1919
const case = DoubleMM.DoubleMMCase()
20-
const MLengine = Val(nameof(SimpleChains))
21-
const FluxMLengine = Val(nameof(Flux))
2220
scenario = (:default,)
2321
rng = StableRNG(111)
2422

2523
par_templates = get_hybridcase_par_templates(case; scenario)
2624

27-
(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
25+
#n_covar = get_hybridcase_n_covar(case; scenario)
26+
#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
27+
28+
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
29+
) = gen_hybridcase_synthetic(rng, case; scenario);
30+
31+
n_covar = size(xM,1)
2832

29-
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
30-
) = gen_hybridcase_synthetic(case, rng; scenario);
3133

3234
#----- fit g to θMs_true
33-
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
35+
g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
36+
(; transP, transM) = get_hybridcase_transforms(case; scenario)
3437

35-
function loss_g(ϕg, x, g)
38+
function loss_g(ϕg, x, g, transM)
3639
ζMs = g(x, ϕg) # predict the log of the parameters
37-
θMs = exp.(ζMs)
40+
θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column
3841
loss = sum(abs2, θMs .- θMs_true)
3942
return loss, θMs
4043
end
41-
loss_g(ϕg0, xM, g)
42-
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
44+
loss_g(ϕg0, xM, g, transM)
4345

44-
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
46+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1],
4547
Optimization.AutoZygote())
4648
optprob = Optimization.OptimizationProblem(optf, ϕg0);
4749
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800);
4850

4951
ϕg_opt1 = res.u;
50-
loss_g(ϕg_opt1, xM, g)
51-
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
52-
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
52+
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
53+
scatterplot(vec(θMs_true), vec(θMs_pred))
5354

5455
f = get_hybridcase_PBmodel(case; scenario)
56+
py = get_hybridcase_neg_logden_obs(case; scenario)
5557

5658
#----------- fit g and θP to y_o
5759
() -> begin
@@ -62,7 +64,7 @@ f = get_hybridcase_PBmodel(case; scenario)
6264
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true
6365

6466
# Pass the site-data for the batches as separate vectors wrapped in a tuple
65-
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
67+
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch)
6668

6769
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
6870
l1 = loss_gf(p0, train_loader.data...)[1]
@@ -82,15 +84,16 @@ f = get_hybridcase_PBmodel(case; scenario)
8284
end
8385

8486
#---------- HVI
85-
logσ2y = 2 .* log.(σ_o)
8687
n_MC = 3
87-
transP = elementwise(exp)
88-
transM = Stacked(elementwise(identity), elementwise(exp))
88+
(; transP, transM) = get_hybridcase_transforms(case; scenario)
89+
FT = get_hybridcase_float_type(case; scenario)
8990

9091
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
91-
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
92+
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
9293
ϕ_true = ϕ
9394

95+
96+
9497
() -> begin
9598
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
9699
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -149,49 +152,22 @@ transM = Stacked(elementwise(identity), elementwise(exp))
149152
ϕ_true = inverse_ca(trans_gu, ϕt_true)
150153
end
151154

152-
ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch
155+
ϕ_ini0 = ζ = reduce(
156+
vcat, (
157+
ϕ_true[[:μP]] .* FT(0.001), CA.ComponentVector(ϕg = ϕg0), ϕ_true[[:unc]])) # scratch
153158
#
154-
# true values
155-
ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance
159+
ϕ_ini = ζ = reduce(
160+
vcat, (
161+
ϕ_true[[:μP]] .- FT(0.1), ϕ_true[[:ϕg]] .* FT(1.1), ϕ_true[[:unc]])) # slight disturbance
156162
# hardcoded from HMC inversion
157163
ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
158164
ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
159165
mean_σ_o_MC = 0.006042
160166

161-
# test cost function and gradient
162-
() -> begin
163-
neg_elbo_transnorm_gf(rng, g, f, ϕ_true, y_o[:, 1:n_batch], xM[:, 1:n_batch],
164-
transPMs_batch, map(get_concrete, interpreters);
165-
n_MC = 8, logσ2y)
166-
Zygote.gradient(
167-
ϕ -> neg_elbo_transnorm_gf(
168-
rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch],
169-
transPMs_batch, interpreters; n_MC = 8, logσ2y),
170-
CA.getdata(ϕ_true))
171-
end
172-
173-
# optimize using SimpleChains
174-
() -> begin
175-
train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch)
176-
177-
optf = Optimization.OptimizationFunction(
178-
(ϕ, data) -> begin
179-
xM, y_o = data
180-
neg_elbo_transnorm_gf(
181-
rng, g, f, ϕ, y_o, xM, transPMs_batch,
182-
map(get_concrete, interpreters_g); n_MC = 5, logσ2y)
183-
end,
184-
Optimization.AutoZygote())
185-
optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader)
186-
res = Optimization.solve(
187-
optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 800)
188-
#optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
189-
#res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
190-
end
191-
192-
ϕ = ϕ_ini |> Flux.gpu;
167+
ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
193168
xM_gpu = xM |> Flux.gpu;
194-
g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
169+
scenario_flux = (scenario..., :use_Flux)
170+
g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux);
195171

196172
# optimize using Lux
197173
() -> begin
@@ -216,27 +192,25 @@ g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario
216192
g_flux = g_luxs
217193
end
218194

219-
function fcost(ϕ, xM, y_o)
220-
neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o,
221-
xM, transPMs_batch, map(get_concrete, interpreters);
222-
n_MC = 8, logσ2y = logσ2y)
195+
function fcost(ϕ, xM, y_o, y_unc)
196+
neg_elbo_transnorm_gf(rng, CA.getdata(ϕ), g_flux, transPMs_batch, f, py,
197+
xM, xP, y_o, y_unc, map(get_concrete, interpreters);
198+
n_MC = 8)
223199
end
224-
fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch])
200+
fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch])
225201
#Zygote.gradient(fcost, ϕ) |> cpu;
226202
gr = Zygote.gradient(fcost,
227-
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
228-
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
203+
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]),
204+
CA.getdata(y_o[:, 1:n_batch]), CA.getdata(y_unc[:, 1:n_batch]));
205+
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
229206

230-
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
231-
train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
207+
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
208+
#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux))
232209

233210
optf = Optimization.OptimizationFunction(
234211
(ϕ, data) -> begin
235-
xM, y_o = data
236-
fcost(ϕ, xM, y_o)
237-
# neg_elbo_transnorm_gf(
238-
# rng, g_flux, f, ϕ, y_o, xM, transPMs_batch,
239-
# map(get_concrete, interpreters); n_MC = 5, logσ2y)
212+
xM, xP, y_o, y_unc = data
213+
fcost(ϕ, xM, y_o, y_unc)
240214
end,
241215
Optimization.AutoZygote())
242216
optprob = Optimization.OptimizationProblem(
@@ -256,7 +230,7 @@ end
256230
ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu
257231
ϕunc_VI = interpreters.unc(ζ_VIc.unc)
258232

259-
hcat(θP_true, exp.(ζ_VIc.μP))
233+
hcat(log.(θP_true), ϕ_ini.μP, ζ_VIc.μP)
260234
plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
261235
#lineplot!(plt, 0.0, 1.1, identity)
262236
#
@@ -266,11 +240,12 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample
266240
# test predicting correct obs-uncertainty of predictive posterior
267241
n_sample_pred = 200
268242

269-
y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters;
243+
y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters;
270244
get_transPMs, get_ca_int_PMs, n_sample_pred);
271245
size(y_pred) # n_obs x n_site, n_sample_pred
272246

273247
σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3);
248+
σ_o = exp.(y_unc[:,1] / 2)
274249

275250
#describe(σ_o_post)
276251
hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)),
@@ -282,7 +257,7 @@ histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_tru
282257

283258
# look at θP, θM1 of first site
284259
intm_PMs_gen = get_ca_int_PMs(n_site)
285-
ζs, _σ = HVI.generate_ζ(rng, g_flux, f, res.u, xM_gpu,
260+
ζs, _σ = HVI.generate_ζ(rng, g_flux, res.u, xM_gpu,
286261
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred);
287262
ζs = ζs |> Flux.cpu;
288263
θPM = vcat(θP_true, θMs_true[:, 1])

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ module HybridVariationalInferenceFluxExt
33
using HybridVariationalInference, Flux
44
using HybridVariationalInference: HybridVariationalInference as HVI
55
using ComponentArrays: ComponentArrays as CA
6+
using Random
67

78
struct FluxApplicator{RT} <: AbstractModelApplicator
89
rebuild::RT
910
end
1011

11-
function HVI.construct_FluxApplicator(m::Chain)
12+
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
13+
# TODO: take care of rng and float_type (both are currently unused here)
1214
ϕ, rebuild = destructure(m)
1315
FluxApplicator(rebuild), ϕ
1416
end
@@ -26,18 +28,21 @@ function __init__()
2628
HVI.set_default_GPUHandler(FluxGPUDataHandler())
2729
end
2830

29-
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
30-
args...; kwargs...)
31-
# constructor with Flux.Chain
32-
g, ϕg = construct_FluxApplicator(g_chain)
33-
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
34-
end
31+
# function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
32+
# args...; kwargs...)
33+
# # constructor with Flux.Chain
34+
# g, ϕg = construct_FluxApplicator(g_chain)
35+
# HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
36+
# end
3537

36-
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
38+
function HVI.construct_3layer_MLApplicator(
39+
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux};
3740
scenario::NTuple = ())
38-
(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
39-
FloatType = get_hybridcase_FloatType(case; scenario)
40-
n_out = n_θM
41+
(;θM) = get_hybridcase_par_templates(case; scenario)
42+
n_out = length(θM)
43+
n_covar = get_hybridcase_n_covar(case; scenario)
44+
#(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
45+
float_type = get_hybridcase_float_type(case; scenario)
4146
is_using_dropout = :use_dropout ∈ scenario
4247
is_using_dropout && error("dropout scenario not supported with Flux yet.")
4348
g_chain = Flux.Chain(
@@ -47,7 +52,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
4752
# dense layer without bias that maps to n outputs and `identity` activation
4853
Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
4954
)
50-
construct_FluxApplicator(g_chain)
55+
construct_ChainsApplicator(rng, g_chain, float_type)
5156
end
5257

5358

ext/HybridVariationalInferenceLuxExt.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator
1010
int_ϕ::IT
1111
end
1212

13-
function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
14-
ps, st = Lux.setup(Random.default_rng(), m)
13+
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type=Float32; device = gpu_device())
14+
ps, st = Lux.setup(rng, m)
1515
ps_ca = float_type.(CA.ComponentArray(ps))
1616
st = st |> device
1717
stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
@@ -25,11 +25,12 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ)
2525
app.stateful_layer(x, ϕc)
2626
end
2727

28-
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
29-
args...; device = gpu_device(), kwargs...)
30-
# constructor with SimpleChain
31-
g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
32-
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
33-
end
28+
# function HVI.HybridProblem(rng::AbstractRNG,
29+
# θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
30+
# args...; device = gpu_device(), kwargs...)
31+
# # constructor with Lux Chain
32+
# g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM); device)
33+
# HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
34+
# end
3435

3536
end # module

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,28 @@ using HybridVariationalInference, SimpleChains
44
using HybridVariationalInference: HybridVariationalInference as HVI
55
using StatsFuns: logistic
66
using ComponentArrays: ComponentArrays as CA
7+
using Random
78

89

910

1011
struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
1112
m::MT
1213
end
1314

14-
function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
15-
ϕ = SimpleChains.init_params(m, FloatType);
15+
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::SimpleChain, FloatType=Float32)
16+
ϕ = SimpleChains.init_params(m, FloatType; rng);
1617
SimpleChainsApplicator(m), ϕ
1718
end
1819

1920
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
2021

21-
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
22-
args...; kwargs...)
23-
# constructor with SimpleChain
24-
g, ϕg = construct_SimpleChainsApplicator(g_chain)
25-
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
26-
end
27-
28-
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
29-
scenario::NTuple=())
30-
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
31-
FloatType = get_hybridcase_FloatType(case; scenario)
32-
n_out = n_θM
22+
function HVI.construct_3layer_MLApplicator(
23+
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains};
24+
scenario::NTuple = ())
25+
n_covar = get_hybridcase_n_covar(case; scenario)
26+
FloatType = get_hybridcase_float_type(case; scenario)
27+
(;θM) = get_hybridcase_par_templates(case; scenario)
28+
n_out = length(θM)
3329
is_using_dropout = :use_dropout ∈ scenario
3430
g_chain = if is_using_dropout
3531
SimpleChain(
@@ -52,7 +48,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
5248
TurboDense{false}(identity, n_out)
5349
)
5450
end
55-
construct_SimpleChainsApplicator(g_chain, FloatType)
51+
construct_ChainsApplicator(rng, g_chain, FloatType)
5652
end
5753

5854
end # module

0 commit comments

Comments
 (0)