Skip to content

Commit 4246557

Browse files
committed
implement and test a type-stable rebuild FluxApplicator
test show no significant speedup
1 parent 0d61fcd commit 4246557

File tree

4 files changed

+39
-1
lines changed

4 files changed

+39
-1
lines changed

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ struct FluxApplicator{RT} <: AbstractModelApplicator
1010
rebuild::RT
1111
end
1212

13+
struct PartricFluxApplicator{RT, MT, YT} <: AbstractModelApplicator
14+
rebuild::RT
15+
end
16+
17+
const FluxApplicatorU{RT} = Union{FluxApplicator{RT},PartricFluxApplicator{RT}} where RT
18+
19+
1320
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
1421
# TODO: care fore rng and float_type
1522
ϕ, rebuild = Flux.destructure(m)
@@ -23,7 +30,18 @@ function HVI.apply_model(app::FluxApplicator, x::T, ϕ) where T
2330
res = m(x)::T
2431
res
2532
end
26-
function _apply_model(m,x) # function barrier so that m is inferred
33+
34+
35+
function HVI.construct_partric(app::FluxApplicator{RT}, x, ϕ) where RT
36+
m = app.rebuild(ϕ)
37+
y = m(x)
38+
PartricFluxApplicator{RT, typeof(m), typeof(y)}(app.rebuild)
39+
end
40+
41+
function HVI.apply_model(app::PartricFluxApplicator{RT, MT, YT}, x, ϕ) where {RT, MT, YT}
42+
m = app.rebuild(ϕ)::MT
43+
res = m(x)::YT
44+
res
2745
end
2846

2947
# struct FluxGPUDataHandler <: AbstractGPUDataHandler end
@@ -70,4 +88,5 @@ end
7088

7189

7290

91+
7392
end # module

src/HybridVariationalInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ include("bijectors_utils.jl")
2828
export AbstractComponentArrayInterpreter, ComponentArrayInterpreter,
2929
StaticComponentArrayInterpreter
3030
export flatten1, get_concrete, get_positions, stack_ca_int, compose_interpreters
31+
export construct_partric
3132
include("ComponentArrayInterpreter.jl")
3233

3334
export AbstractModelApplicator, construct_ChainsApplicator

src/ModelApplicator.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ function apply_model(app::NullModelApplicator, x, ϕ)
3434
return x
3535
end
3636

37+
"""
38+
Construct a parametric type-stable model applicator, given
39+
covariates, `x`, and parameters, `ϕ`.
40+
41+
The default retuns the current model applicator.
42+
"""
43+
function construct_partric(app::AbstractModelApplicator, x, ϕ)
44+
app
45+
end
46+
3747

3848
"""
3949
construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type)

test/test_Flux.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ using Flux
5656
# @descend_code_warntype g(x, ϕ)
5757
#@test ϕ isa GPUArraysCore.AbstractGPUArray
5858
@test size(y) == (n_out, n_site)
59+
gp = construct_partric(g, x, ϕ)
60+
y2 = @inferred gp(x, ϕ)
61+
@test y2 == y
62+
() -> begin
63+
# @usingany BenchmarkTools
64+
#@benchmark g(x,ϕ)
65+
#@benchmark gp(x,ϕ) # no difference type-inferred
66+
end
5967
end;
6068

6169
@testset "cpu_ca" begin

0 commit comments

Comments
 (0)