Skip to content

Commit 1ccfa4a

Browse files
committed
implement get_hybridcase_train_dataloader
1 parent e731305 commit 1ccfa4a

File tree

8 files changed

+71
-62
lines changed

8 files changed

+71
-62
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1212
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1313
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1718
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -38,6 +39,7 @@ Flux = "v0.15.2, 0.16"
3839
GPUArraysCore = "0.1, 0.2"
3940
LinearAlgebra = "1.10.0"
4041
Lux = "1.4.2"
42+
MLUtils = "0.4.5"
4143
Random = "1.10.0"
4244
SimpleChains = "0.4"
4345
StatsBase = "0.34.4"

dev/doubleMM.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ gr = Zygote.gradient(fcost,
227227
CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch]));
228228
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
229229

230-
train_loader = MLUtils.DataLoader((xM_gpu, y_o), batchsize = n_batch)
230+
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
231+
train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
231232

232233
optf = Optimization.OptimizationFunction(
233234
(ϕ, data) -> begin

src/DoubleMM/f_doubleMM.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,5 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
8585
)
8686
end
8787

88+
89+

src/HybridProblem.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ struct HybridProblem <: AbstractHybridCase
88
f
99
g
1010
ϕg
11+
train_loader
1112
# inner constructor to constrain the types
1213
function HybridProblem(
1314
θP::CA.ComponentVector, θM::CA.ComponentVector,
1415
transM::Union{Function, Bijectors.Transform},
1516
transP::Union{Function, Bijectors.Transform},
1617
n_covar::Integer, n_batch::Integer,
17-
f::Function, g::AbstractModelApplicator, ϕg)
18-
new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg)
18+
f::Function, g::AbstractModelApplicator, ϕg, train_loader::DataLoader)
19+
new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg, train_loader)
1920
end
2021
end
2122

@@ -37,6 +38,13 @@ function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::N
3738
prob.g, prob.ϕg
3839
end
3940

41+
function get_hybridcase_train_dataloader(
42+
prob::HybridProblem, rng::AbstractRNG = Random.default_rng();
43+
scenario = ())
44+
return(prob.train_loader)
45+
end
46+
47+
4048
# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
4149
# eltype(prob.θM)
4250
# end

src/HybridVariationalInference.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using ChainRulesCore
1111
using Bijectors
1212
using Zygote # Zygote.@ignore CUDA.randn
1313
using BlockDiagonals
14+
using MLUtils # dataloader
1415

1516
export ComponentArrayInterpreter, flatten1, get_concrete
1617
include("ComponentArrayInterpreter.jl")
@@ -23,7 +24,8 @@ export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler
2324
include("GPUDataHandler.jl")
2425

2526
export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic,
26-
get_hybridcase_par_templates, get_hybridcase_transforms, gen_cov_pred
27+
get_hybridcase_par_templates, get_hybridcase_transforms, get_hybridcase_train_dataloader,
28+
gen_cov_pred
2729
include("hybrid_case.jl")
2830

2931
export HybridProblem

src/hybrid_case.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ Type to dispatch constructing data and network structures
33
for different cases of hybrid problem setups
44
55
For a specific case, provide functions that specify details
6-
- get_hybridcase_par_templates
7-
- get_hybridcase_transforms
8-
- get_hybridcase_sizes
9-
- get_hybridcase_MLapplicator
10-
- get_hybridcase_PBmodel
6+
- `get_hybridcase_par_templates`
7+
- `get_hybridcase_transforms`
8+
- `get_hybridcase_sizes`
9+
- `get_hybridcase_MLapplicator`
10+
- `get_hybridcase_PBmodel`
11+
- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
1112
optionally
12-
- gen_hybridcase_synthetic
13-
- get_hybridcase_FloatType (defaults to eltype(θM))
13+
- `gen_hybridcase_synthetic`
14+
- `get_hybridcase_FloatType` (defaults to eltype(θM))
1415
"""
1516
abstract type AbstractHybridCase end;
1617

@@ -96,4 +97,22 @@ function get_hybridcase_FloatType(case::AbstractHybridCase; scenario)
9697
return eltype(get_hybridcase_par_templates(case; scenario).θM)
9798
end
9899

100+
"""
101+
get_hybridcase_train_dataloader(::AbstractHybridCase, rng; scenario)
102+
103+
Return a DataLoader that provides a tuple of
104+
- `xM`: matrix of covariates, with one column per site
105+
- `xP`: Iterator of process-model drivers, with one element per site
106+
- `y_o`: matrix of observations with added noise, with one column per site
107+
"""
108+
function get_hybridcase_train_dataloader(case::AbstractHybridCase, rng::AbstractRNG;
109+
scenario = ())
110+
(; xM, xP, y_o) = gen_hybridcase_synthetic(case, rng; scenario)
111+
(; n_batch) = get_hybridcase_sizes(case; scenario)
112+
xM_gpu = :use_flux ∈ scenario ? CuArray(xM) : xM
113+
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
114+
return(train_loader)
115+
end
116+
117+
99118

test/test_HybridProblem.jl

Lines changed: 24 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -46,76 +46,50 @@ construct_problem = () -> begin
4646
)
4747
g = construct_SimpleChainsApplicator(g_chain)
4848
ϕg = SimpleChains.init_params(g_chain, eltype(θM))
49-
HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, g, ϕg)
49+
#
50+
rng = StableRNG(111)
51+
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
52+
) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng;);
53+
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
54+
HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global,
55+
g, ϕg, train_loader)
5056
end
5157
prob = construct_problem();
52-
case_syn = DoubleMM.DoubleMMCase()
5358
scenario = (:default,)
5459

55-
par_templates = get_hybridcase_par_templates(prob; scenario)
56-
57-
(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(prob; scenario)
58-
59-
rng = StableRNG(111)
60-
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
61-
) = gen_hybridcase_synthetic(case_syn, rng; scenario);
62-
63-
@testset "loss_g" begin
64-
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
65-
66-
function loss_g(ϕg, x, g)
67-
ζMs = g(x, ϕg) # predict the log of the parameters
68-
θMs = exp.(ζMs)
69-
loss = sum(abs2, θMs .- θMs_true)
70-
return loss, θMs
71-
end
72-
loss_g(ϕg0, xM, g)
73-
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0)
7460

75-
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
76-
Optimization.AutoZygote())
77-
optprob = Optimization.OptimizationProblem(optf, ϕg0)
78-
#res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
79-
res = Optimization.solve(optprob, Adam(0.02), maxiters = 600)
80-
81-
ϕg_opt1 = res.u
82-
pred = loss_g(ϕg_opt1, xM, g)
83-
θMs_pred = pred[2]
84-
#scatterplot(vec(θMs_true), vec(θMs_pred))
85-
@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9
86-
end
61+
#(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(prob; scenario)
8762

8863
@testset "loss_gf" begin
8964
#----------- fit g and θP to y_o
9065
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
66+
train_loader = get_hybridcase_train_dataloader(prob; scenario)
67+
(xM, xP, y_o) = first(train_loader)
9168
f = get_hybridcase_PBmodel(prob; scenario)
69+
par_templates = get_hybridcase_par_templates(prob; scenario)
9270

9371
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
9472
ϕg = 1:length(ϕg0), θP = par_templates.θP))
9573
p = p0 = vcat(ϕg0, par_templates.θP .* 0.8) # slightly disturb θP_true
9674

9775
# Pass the site-data for the batches as separate vectors wrapped in a tuple
98-
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
9976

77+
y_global_o = Float64[]
10078
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
101-
l1 = loss_gf(p0, train_loader.data...)[1]
79+
l1 = loss_gf(p0, first(train_loader)...)
80+
gr = Zygote.gradient(p -> loss_gf(p, train_loader.data...)[1], p0)
81+
@test gr[1] isa Vector
10282

103-
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
104-
Optimization.AutoZygote())
105-
optprob = OptimizationProblem(optf, p0, train_loader)
106-
107-
res = Optimization.solve(
108-
# optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
109-
optprob, Adam(0.02), maxiters = 1000)
83+
() -> begin
84+
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
85+
Optimization.AutoZygote())
86+
optprob = OptimizationProblem(optf, p0, train_loader)
11087

111-
l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
112-
@test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)
113-
@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9
88+
res = Optimization.solve(
89+
# optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
90+
optprob, Adam(0.02), maxiters = 1000)
11491

115-
() -> begin
116-
scatterplot(vec(θMs_true), vec(θMs_pred))
117-
scatterplot(log.(vec(θMs_true)), log.(vec(θMs_pred)))
118-
scatterplot(vec(y_pred), vec(y_o))
119-
hcat(par_templates.θP, int_ϕθP(p0).θP, int_ϕθP(res.u).θP)
92+
l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
93+
@test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)
12094
end
12195
end

test/test_doubleMM.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ end
7070
p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true
7171

7272
# Pass the site-data for the batches as separate vectors wrapped in a tuple
73-
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
73+
#train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
74+
train_loader = get_hybridcase_train_dataloader(case, rng; scenario)
7475

7576
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
7677
l1 = loss_gf(p0, train_loader.data...)[1]

0 commit comments

Comments
 (0)