Skip to content

Commit 36c63be

Browse files
committed
remove MLEngine from get_hybridcase_MLapplicator
better infer it from scenario, although this looses some type-stability.
1 parent df48ad4 commit 36c63be

12 files changed

+66
-31
lines changed

dev/doubleMM.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Statistics
77
using ComponentArrays: ComponentArrays as CA
88

99
using SimpleChains
10-
import Flux # to allow for FluxMLEngine and cpu()
10+
import Flux
1111
using MLUtils
1212
import Zygote
1313

@@ -17,20 +17,22 @@ using Bijectors
1717
using UnicodePlots
1818

1919
const case = DoubleMM.DoubleMMCase()
20-
const MLengine = Val(nameof(SimpleChains))
21-
const FluxMLengine = Val(nameof(Flux))
2220
scenario = (:default,)
2321
rng = StableRNG(111)
2422

2523
par_templates = get_hybridcase_par_templates(case; scenario)
2624

27-
(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
25+
#n_covar = get_hybridcase_n_covar(case; scenario)
26+
#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
2827

2928
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
3029
) = gen_hybridcase_synthetic(rng, case; scenario);
3130

31+
n_covar = size(xM,1)
32+
33+
3234
#----- fit g to θMs_true
33-
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
35+
g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
3436
(; transP, transM) = get_hybridcase_transforms(case; scenario)
3537

3638
function loss_g(ϕg, x, g, transM)
@@ -90,6 +92,8 @@ FT = get_hybridcase_float_type(case; scenario)
9092
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
9193
ϕ_true = ϕ
9294

95+
96+
9397
() -> begin
9498
coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
9599
logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893)
@@ -162,7 +166,8 @@ mean_σ_o_MC = 0.006042
162166

163167
ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
164168
xM_gpu = xM |> Flux.gpu;
165-
g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
169+
scenario_flux = (scenario..., :use_Flux)
170+
g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux);
166171

167172
# otpimize using LUX
168173
() -> begin
@@ -200,7 +205,7 @@ gr = Zygote.gradient(fcost,
200205
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
201206

202207
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
203-
#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
208+
#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux))
204209

205210
optf = Optimization.OptimizationFunction(
206211
(ϕ, data) -> begin

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ end
3535
# HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
3636
# end
3737

38-
function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
38+
function HVI.construct_3layer_MLApplicator(
39+
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux};
3940
scenario::NTuple = ())
4041
(;θM) = get_hybridcase_par_templates(case; scenario)
4142
n_out = length(θM)
42-
n_covar = 5
43+
n_covar = get_hybridcase_n_covar(case; scenario)
4344
#(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
4445
float_type = get_hybridcase_float_type(case; scenario)
4546
is_using_dropout = :use_dropout scenario

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ end
1919

2020
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
2121

22-
function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
23-
scenario::NTuple=())
22+
function HVI.construct_3layer_MLApplicator(
23+
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains};
24+
scenario::NTuple = ())
2425
n_covar = get_hybridcase_n_covar(case; scenario)
2526
FloatType = get_hybridcase_float_type(case; scenario)
2627
(;θM) = get_hybridcase_par_templates(case; scenario)

src/DoubleMM/f_doubleMM.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase;
9191
)
9292
end
9393

94+
function HVI.get_hybridcase_MLapplicator(
95+
rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase; scenario = ())
96+
ml_engine = select_ml_engine(; scenario)
97+
construct_3layer_MLApplicator(rng, case, ml_engine; scenario)
98+
end
99+
100+
94101

95102

96103

src/HybridProblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ())
5454
prob.f
5555
end
5656

57-
function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::NTuple = ());
57+
function get_hybridcase_MLapplicator(prob::HybridProblem; scenario::NTuple = ());
5858
prob.g, prob.ϕg
5959
end
6060

src/HybridVariationalInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export ComponentArrayInterpreter, flatten1, get_concrete
1717
include("ComponentArrayInterpreter.jl")
1818

1919
export AbstractModelApplicator, construct_ChainsApplicator
20+
export construct_3layer_MLApplicator, select_ml_engine
2021
include("ModelApplicator.jl")
2122

2223
export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler

src/ModelApplicator.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,31 @@ end
3636
# function construct_LuxApplicator end
3737

3838

39+
"""
40+
construct_3layer_MLApplicator(
41+
rng::AbstractRNG, case::HVI.AbstractHybridCase, <ml_engine>;
42+
scenario::NTuple = ())
43+
44+
`ml_engine` usually is of type `Val{Symbol}`, e.g. Val(:Flux). See `select_ml_engine`.
45+
"""
46+
function construct_3layer_MLApplicator end
47+
48+
"""
49+
select_ml_engine(;scenario)
50+
51+
Returns a value type `Val{:Symbol}` to dispatch on the machine learning engine to use.
52+
- defaults to `Val(:SimpleChains)`
53+
- `:use_Lux ∈ scenario -> Val(:Lux)`
54+
- `:use_Flux ∈ scenario -> Val(:Flux)`
55+
"""
56+
function select_ml_engine(;scenario)
57+
if :use_Lux scenario
58+
return Val(:Lux)
59+
elseif :use_Flux scenario
60+
return Val(:Flux)
61+
else
62+
# default
63+
return Val(:SimpleChains)
64+
end
65+
end
66+

src/hybrid_case.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,19 @@ abstract type AbstractHybridCase end;
1919

2020

2121
"""
22-
get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase, MLEngine; scenario=())
22+
get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase; scenario=())
2323
2424
Construct the machine learning model fro given problem case and ML-Framework and
2525
scenario.
2626
27-
The MLEngine is a value type of a Symbol, usually the name of the module, e.g.
28-
`const MLengine = Val(nameof(SimpleChains))`.
29-
3027
returns a Tuple of
3128
- AbstractModelApplicator
3229
- initial parameter vector
3330
"""
3431
function get_hybridcase_MLapplicator end
3532

36-
function get_hybridcase_MLapplicator(case::AbstractHybridCase, MLEngine; scenario=())
37-
get_hybridcase_MLapplicator(Random.default_rng(), case, MLEngine; scenario)
33+
function get_hybridcase_MLapplicator(case::AbstractHybridCase; scenario=())
34+
get_hybridcase_MLapplicator(Random.default_rng(), case; scenario)
3835
end
3936

4037
"""
@@ -138,7 +135,7 @@ function get_hybridcase_train_dataloader(rng::AbstractRNG, case::AbstractHybridC
138135
scenario = ())
139136
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, case; scenario)
140137
n_batch = 10
141-
xM_gpu = :use_flux scenario ? CuArray(xM) : xM
138+
xM_gpu = :use_Flux scenario ? CuArray(xM) : xM
142139
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
143140
return(train_loader)
144141
end

test/test_HybridProblem.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import Zygote
1212

1313
using OptimizationOptimisers
1414

15-
const MLengine = Val(nameof(SimpleChains))
16-
1715
construct_problem = () -> begin
1816
FT = Float32
1917
θP = CA.ComponentVector{FT}(r0=0.3, K2=2.0)
@@ -63,7 +61,7 @@ scenario = (:default,)
6361
@testset "loss_gf" begin
6462
#----------- fit g and θP to y_o
6563
rng = StableRNG(111)
66-
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario)
64+
g, ϕg0 = get_hybridcase_MLapplicator(prob; scenario)
6765
train_loader = get_hybridcase_train_dataloader(rng, prob; scenario)
6866
(xM, xP, y_o, y_unc) = first(train_loader)
6967
f = get_hybridcase_PBmodel(prob; scenario)
@@ -101,7 +99,7 @@ import Flux
10199

102100
@testset "neg_elbo_transnorm_gf cpu" begin
103101
rng = StableRNG(111)
104-
g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine)
102+
g, ϕg0 = get_hybridcase_MLapplicator(prob)
105103
train_loader = get_hybridcase_train_dataloader(rng, prob)
106104
(xM, xP, y_o, y_unc) = first(train_loader)
107105
n_batch = size(y_o, 2)

test/test_doubleMM.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import Zygote
1111

1212
using OptimizationOptimisers
1313

14-
const MLengine = Val(nameof(SimpleChains))
1514
const case = DoubleMM.DoubleMMCase()
1615
scenario = (:default,)
1716

@@ -34,7 +33,7 @@ rng = StableRNG(111)
3433
end
3534

3635
@testset "loss_g" begin
37-
g, ϕg0 = get_hybridcase_MLapplicator(rng, case, MLengine; scenario);
36+
g, ϕg0 = get_hybridcase_MLapplicator(rng, case; scenario);
3837
(;transP, transM) = get_hybridcase_transforms(case; scenario)
3938

4039
function loss_g(ϕg, x, g, transM)
@@ -67,7 +66,7 @@ end
6766

6867
@testset "loss_gf" begin
6968
#----------- fit g and θP to y_o (without uncertainty, without transforming θP)
70-
g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
69+
g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
7170
(;transP, transM) = get_hybridcase_transforms(case; scenario)
7271
f = get_hybridcase_PBmodel(case; scenario)
7372

0 commit comments

Comments
 (0)