Skip to content

Commit 3086e10

Browse files
committed
HybridProblem constructors with SimpleChain, Flux.Chain and Lux.Chain
1 parent bfce57b commit 3086e10

9 files changed

+55
-36
lines changed

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ struct FluxApplicator{RT} <: AbstractModelApplicator
99
end
1010

1111
function HVI.construct_FluxApplicator(m::Chain)
12-
_, rebuild = destructure(m)
13-
FluxApplicator(rebuild)
12+
ϕ, rebuild = destructure(m)
13+
FluxApplicator(rebuild), ϕ
1414
end
1515

1616
function HVI.apply_model(app::FluxApplicator, x, ϕ)
@@ -29,8 +29,7 @@ end
2929
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
3030
args...; kwargs...)
3131
# constructor with Flux.Chain
32-
ϕ, _ = destructure(g_chain)
33-
g = construct_FluxApplicator(g_chain), ϕ
32+
g, ϕg = construct_FluxApplicator(g_chain)
3433
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
3534
end
3635

@@ -48,8 +47,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
4847
# dense layer without bias that maps to n outputs and `identity` activation
4948
Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
5049
)
51-
ϕ, _ = destructure(g_chain)
52-
construct_FluxApplicator(g_chain), ϕ
50+
construct_FluxApplicator(g_chain)
5351
end
5452

5553

ext/HybridVariationalInferenceLuxExt.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator
1010
int_ϕ::IT
1111
end
1212

13-
function HVI.construct_LuxApplicator(m::Chain; device = gpu_device())
13+
function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
1414
ps, st = Lux.setup(Random.default_rng(), m)
15-
ps_ca = CA.ComponentArray(ps)
15+
ps_ca = float_type.(CA.ComponentArray(ps))
1616
st = st |> device
1717
stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
1818
#stateful_layer(x_o_gpu[:, 1:n_site_batch], ps_ca)
1919
int_ϕ = get_concrete(ComponentArrayInterpreter(ps_ca))
20-
LuxApplicator(stateful_layer, int_ϕ)
20+
LuxApplicator(stateful_layer, int_ϕ), ps_ca
2121
end
2222

2323
function HVI.apply_model(app::LuxApplicator, x, ϕ)
@@ -26,11 +26,9 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ)
2626
end
2727

2828
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
29-
args...; kwargs...)
29+
args...; device = gpu_device(), kwargs...)
3030
# constructor with Lux.Chain
31-
g = construct_LuxApplicator(g_chain)
32-
FT = eltype(θM)
33-
ϕg = randn(FT, length(g.int_ϕ))
31+
g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
3432
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
3533
end
3634

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@ struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
1111
m::MT
1212
end
1313

14-
HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
14+
function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
15+
ϕ = SimpleChains.init_params(m, FloatType);
16+
SimpleChainsApplicator(m), ϕ
17+
end
1518

1619
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
1720

1821
function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
1922
args...; kwargs...)
2023
# constructor with SimpleChain
21-
g = construct_SimpleChainsApplicator(g_chain)
22-
ϕg = SimpleChains.init_params(g_chain, eltype(θM))
24+
g, ϕg = construct_SimpleChainsApplicator(g_chain)
2325
HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
2426
end
2527

@@ -50,8 +52,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
5052
TurboDense{false}(identity, n_out)
5153
)
5254
end
53-
ϕ = SimpleChains.init_params(g_chain, FloatType);
54-
SimpleChainsApplicator(g_chain), ϕ
55+
construct_SimpleChainsApplicator(g_chain, FloatType)
5556
end
5657

5758
end # module

src/ModelApplicator.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
"""
2+
AbstractModelApplicator(x, ϕ)
3+
4+
Abstraction of applying a machine learning model to a covariate matrix, `x`,
5+
using parameters, `ϕ`. It returns a matrix of predictions with the same
6+
number of columns as in `x`.
7+
8+
Constructors for specifics are defined in extension packages.
9+
Each constructor takes a special type of machine learning model and returns
10+
a tuple with two components:
11+
- The applicator
12+
- a sample parameter vector (type depends on the used ML-framework)
13+
14+
Implemented are
15+
- `construct_SimpleChainsApplicator`
16+
- `construct_FluxApplicator`
17+
- `construct_LuxApplicator`
18+
"""
119
abstract type AbstractModelApplicator end
220

321
function apply_model end

test/test_Flux.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,19 @@ end;
3535
Dense(n_covar * 4 => n_covar * 4, tanh),
3636
Dense(n_covar * 4 => n_out, identity, bias=false),
3737
)
38-
g = construct_FluxApplicator(g_chain)
38+
g, ϕg = construct_FluxApplicator(g_chain |> f64)
39+
@test eltype(ϕg) == Float64
40+
g, ϕg = construct_FluxApplicator(g_chain)
41+
@test eltype(ϕg) == Float32
3942
n_site = 3
4043
x = rand(Float32, n_covar, n_site)
41-
ϕ, _rebuild = destructure(g_chain)
42-
y = g(x, ϕ)
44+
#ϕ, _rebuild = destructure(g_chain)
45+
y = g(x, ϕg)
4346
@test size(y) == (n_out, n_site)
4447
#
4548
n_site = 3
4649
x = rand(Float32, n_covar, n_site) |> gpu
47-
ϕ = ϕ |> gpu
50+
ϕ = ϕg |> gpu
4851
y = g(x, ϕ)
4952
#@test ϕ isa GPUArraysCore.AbstractGPUArray
5053
@test size(y) == (n_out, n_site)

test/test_HybridProblem.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ construct_problem = () -> begin
4444
# dense layer without bias that maps to n outputs and `identity` activation
4545
TurboDense{false}(identity, n_out)
4646
)
47-
# g = construct_SimpleChainsApplicator(g_chain)
48-
# ϕg = SimpleChains.init_params(g_chain, eltype(θM))
47+
# g, ϕg = construct_SimpleChainsApplicator(g_chain)
4948
#
5049
rng = StableRNG(111)
5150
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o

test/test_Lux.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using HybridVariationalInference
22
using Test
3+
using CUDA, GPUArraysCore
34
using Lux
45
using StatsFuns: logistic
5-
using CUDA, GPUArraysCore
66

77

88
@testset "LuxModelApplicator" begin
@@ -13,18 +13,20 @@ using CUDA, GPUArraysCore
1313
Dense(n_covar * 4 => n_covar * 4, tanh),
1414
Dense(n_covar * 4 => n_out, logistic, use_bias=false),
1515
);
16-
g = construct_LuxApplicator(g_chain; device = cpu_device());
16+
g, ϕ = construct_LuxApplicator(g_chain, Float64; device = cpu_device());
17+
@test eltype(ϕ) == Float64
18+
g, ϕ = construct_LuxApplicator(g_chain; device = cpu_device());
19+
@test eltype(ϕ) == Float32
1720
n_site = 3
1821
x = rand(Float32, n_covar, n_site)
19-
ϕ = randn(Float32, Lux.parameterlength(g_chain))
22+
#ϕ = randn(Float32, Lux.parameterlength(g_chain))
2023
y = g(x, ϕ)
2124
@test size(y) == (n_out, n_site)
2225
#
23-
g = construct_LuxApplicator(g_chain; device = gpu_device());
24-
n_site = 3
2526
x = rand(Float32, n_covar, n_site) |> gpu_device()
26-
ϕ = randn(Float32, Lux.parameterlength(g_chain)) |> gpu_device()
27-
y = g(x, ϕ)
27+
ϕ_gpu = ϕ |> gpu_device()
28+
#ϕ = randn(Float32, Lux.parameterlength(g_chain)) |> gpu_device()
29+
y = g(x, ϕ_gpu)
2830
#@test ϕ isa GPUArraysCore.AbstractGPUArray
2931
@test size(y) == (n_out, n_site)
3032
end;

test/test_SimpleChains.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ using StatsFuns: logistic
1212
TurboDense{true}(tanh, n_covar * 4),
1313
TurboDense{false}(logistic, n_out)
1414
)
15-
g = construct_SimpleChainsApplicator(g_chain)
15+
g, ϕg = construct_SimpleChainsApplicator(g_chain)
1616
n_site = 3
1717
x = rand(n_covar, n_site)
18-
ϕ = SimpleChains.init_params(g_chain);
19-
y = g(x, ϕ)
18+
#ϕg = SimpleChains.init_params(g_chain);
19+
y = g(x, ϕg)
2020
@test size(y) == (n_out, n_site)
2121
end;

test/test_cholesky_structure.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ end
247247
#@test Upred ≈ CU
248248
SUpred = Upred *
249249
#hcat(SUpred, SU)
250-
@test SUpred ≈ SU atol=2e-1
250+
@test SUpred ≈ SU atol=6e-1
251251
S_pred =' * Upred' * Upred *
252-
@test S_pred ≈ S atol=2e-1
252+
@test S_pred ≈ S atol=6e-1
253253
end
254254

0 commit comments

Comments
 (0)