Skip to content

Commit bfce57b

Browse files
committed
HybirdProblem constructors with SimpleChain, Flux.Chain and Lux.Chain
1 parent 1ccfa4a commit bfce57b

File tree

5 files changed

+41
-7
lines changed

5 files changed

+41
-7
lines changed

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module HybridVariationalInferenceFluxExt
22

33
using HybridVariationalInference, Flux
44
using HybridVariationalInference: HybridVariationalInference as HVI
5+
using ComponentArrays: ComponentArrays as CA
56

67
struct FluxApplicator{RT} <: AbstractModelApplicator
78
rebuild::RT
@@ -25,6 +26,14 @@ function __init__()
2526
HVI.set_default_GPUHandler(FluxGPUDataHandler())
2627
end
2728

29+
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
30+
args...; kwargs...)
31+
# constructor with Flux.Chain
32+
ϕ, _ = destructure(g_chain)
33+
g = construct_FluxApplicator(g_chain), ϕ
34+
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
35+
end
36+
2837
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
2938
scenario::NTuple = ())
3039
(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
@@ -43,4 +52,6 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
4352
construct_FluxApplicator(g_chain), ϕ
4453
end
4554

55+
56+
4657
end # module

ext/HybridVariationalInferenceLuxExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,13 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ)
2525
app.stateful_layer(x, ϕc)
2626
end
2727

28+
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
29+
args...; kwargs...)
30+
# constructor with SimpleChain
31+
g = construct_LuxApplicator(g_chain)
32+
FT = eltype(θM)
33+
ϕg = randn(FT, length(g.int_ϕ))
34+
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
35+
end
36+
2837
end # module

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ module HybridVariationalInferenceSimpleChainsExt
33
using HybridVariationalInference, SimpleChains
44
using HybridVariationalInference: HybridVariationalInference as HVI
55
using StatsFuns: logistic
6+
using ComponentArrays: ComponentArrays as CA
7+
8+
69

710
struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
811
m::MT
@@ -12,6 +15,14 @@ HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
1215

1316
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
1417

18+
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
19+
args...; kwargs...)
20+
# constructor with SimpleChain
21+
g = construct_SimpleChainsApplicator(g_chain)
22+
ϕg = SimpleChains.init_params(g_chain, eltype(θM))
23+
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
24+
end
25+
1526
function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
1627
scenario::NTuple=())
1728
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)

src/HybridProblem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ struct HybridProblem <: AbstractHybridCase
1212
# inner constructor to constrain the types
1313
function HybridProblem(
1414
θP::CA.ComponentVector, θM::CA.ComponentVector,
15+
g::AbstractModelApplicator, ϕg,
16+
f::Function,
1517
transM::Union{Function, Bijectors.Transform},
1618
transP::Union{Function, Bijectors.Transform},
1719
n_covar::Integer, n_batch::Integer,
18-
f::Function, g::AbstractModelApplicator, ϕg, train_loader::DataLoader)
20+
train_loader::DataLoader)
1921
new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg, train_loader)
2022
end
2123
end

test/test_HybridProblem.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,21 @@ construct_problem = () -> begin
4444
# dense layer without bias that maps to n outputs and `identity` activation
4545
TurboDense{false}(identity, n_out)
4646
)
47-
g = construct_SimpleChainsApplicator(g_chain)
48-
ϕg = SimpleChains.init_params(g_chain, eltype(θM))
47+
# g = construct_SimpleChainsApplicator(g_chain)
48+
# ϕg = SimpleChains.init_params(g_chain, eltype(θM))
4949
#
5050
rng = StableRNG(111)
5151
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
52-
) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng;);
52+
) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng;)
5353
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
54-
HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global,
55-
g, ϕg, train_loader)
54+
# HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global,
55+
# g, ϕg, train_loader)
56+
HybridProblem(θP, θM, g_chain, f_doubleMM_with_global,
57+
transM, transP, n_covar, n_batch, train_loader)
5658
end
5759
prob = construct_problem();
5860
scenario = (:default,)
5961

60-
6162
#(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(prob; scenario)
6263

6364
@testset "loss_gf" begin

0 commit comments

Comments
 (0)