From 0d61fcd4bcb33b04320f6caa8e0c47e23b527526 Mon Sep 17 00:00:00 2001
From: Thomas Wutzler
Date: Tue, 13 May 2025 07:58:07 +0200
Subject: [PATCH 1/3] type-stable prediction of FluxModelApplicator

---
 ext/HybridVariationalInferenceFluxExt.jl |  8 ++++--
 src/elbo.jl                              | 31 ++++++++++++------------
 src/gf.jl                                |  4 +--
 test/test_Flux.jl                        |  4 ++-
 4 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl
index cda7157..f634270 100644
--- a/ext/HybridVariationalInferenceFluxExt.jl
+++ b/ext/HybridVariationalInferenceFluxExt.jl
@@ -16,11 +16,15 @@ function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::
     FluxApplicator(rebuild), ϕ
 end
 
-function HVI.apply_model(app::FluxApplicator, x, ϕ)
+function HVI.apply_model(app::FluxApplicator, x::T, ϕ) where T
+    # assume no size information in x -> can hint the type of the result
+    # to be the same as the type of the input
     m = app.rebuild(ϕ)
-    res = m(x)
+    res = m(x)::T
     res
 end
+function _apply_model(m,x) # function barrier so that m is inferred
+end
 
 # struct FluxGPUDataHandler <: AbstractGPUDataHandler end
 # HVI.handle_GPU_data(::FluxGPUDataHandler, x::AbstractArray) = cpu(x)
diff --git a/src/elbo.jl b/src/elbo.jl
index 66c4ce4..8bd4708 100644
--- a/src/elbo.jl
+++ b/src/elbo.jl
@@ -296,7 +296,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
     xMP0 = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices)
     #Main.@infiltrate_main
-    μ_ζMs0 = g(xMP0, ϕg)::MT # for gpu restructure returns Any, so apply type
+    μ_ζMs0 = g(xMP0, ϕg)
     ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc.unc;
         n_MC, cor_ends, int_unc)
     if pbm_covar_indices isa SA.SVector{0}
         # do not need to predict again but just add the residuals to μ_ζP and μ_ζMs
@@ -308,7 +308,7 @@
         ζP = μ_ζP .+ rP
         # second pass: append ζP rather than μ_ζP to covars to xM
         xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
-        μ_ζMst = g(xMP, ϕg)::MT # for gpu restructure returns Any, so apply type
+        μ_ζMst = g(xMP, ϕg)
         ζMs = (μ_ζMst .+ rMs)' # already transform to par-last form
         ζP, ζMs
     end
@@ -356,26 +356,27 @@ function get_pbm_covar_indices(ζP, pbm_covars::NTuple{0},
     SA.SA[]
 end
 
-# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N
-# xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name?
+# remove?
+# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N
+# # xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name?
+# # μ_ζMs = g(xMP2, ϕg)
+# # end
+# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0)
+# # # if pbm_covars is the empty tuple, just return the original prediction on xM only
+# # # rather than calling the ML model
+# # μ_ζMs0
+# # end
+
+# function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0)
+# xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
 # μ_ζMs = g(xMP2, ϕg)
 # end
-# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0)
+# function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0)
 # # if pbm_covars is the empty tuple, just return the original prediction on xM only
 # # rather than calling the ML model
 # μ_ζMs0
 # end
 
-function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0)
-    xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
-    μ_ζMs = g(xMP2, ϕg)
-end
-function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0)
-    # if pbm_covars is the empty tuple, just return the original prediction on xM only
-    # rather than calling the ML model
-    μ_ζMs0
-end
-
 """
 Extract relevant parameters from ζ and return n_MC generated multivariate normal draws
 together with the vector of standard deviations, `σ`: `(ζP_resids, ζMs_parfirst_resids, σ)`
diff --git a/src/gf.jl b/src/gf.jl
index 2fefe2d..44ee7eb 100644
--- a/src/gf.jl
+++ b/src/gf.jl
@@ -120,10 +120,10 @@ end
 composition transM ∘ g: transformation after machine learning parameter prediction
 Provide a `transMs = StackedArray(transM, n_batch)`
 """
-function gtrans(g, transMs, xMP::T, ϕg; cdev) where T
+function gtrans(g, transMs, xMP, ϕg; cdev)
     # TODO remove after removing gf
     # predict the log of the parameters
-    ζMst = g(xMP, ϕg)::T # problem of Flux model applicator restructure
+    ζMst = g(xMP, ϕg)
     ζMs = ζMst'
     ζMs_cpu = cdev(ζMs)
     θMs = transMs(ζMs_cpu)
diff --git a/test/test_Flux.jl b/test/test_Flux.jl
index 8cbaa07..d19ddd3 100644
--- a/test/test_Flux.jl
+++ b/test/test_Flux.jl
@@ -51,7 +51,9 @@ using Flux
     n_site = 3
     x = rand(Float32, n_covar, n_site) |> gpu
     ϕ = ϕg |> gpu
-    y = g(x, ϕ)
+    y = @inferred g(x, ϕ)
+    # @usingany Cthulhu
+    # @descend_code_warntype g(x, ϕ)
     #@test ϕ isa GPUArraysCore.AbstractGPUArray
     @test size(y) == (n_out, n_site)
 end;
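
Why the `::T` assertion in this patch helps: `Flux.destructure` returns the
flat parameter vector together with an opaque `Restructure` closure, so the
type of `app.rebuild(ϕ)` cannot be inferred, and `m(x)` then infers as `Any`.
Asserting the result to the input type `T` restores inference, under the
assumption stated in the comment that the chain maps an array to an array of
the same concrete type. A minimal sketch of the effect outside the package
(illustration only; the layer sizes and function names are made up):

    using Flux, Test

    mchain = Chain(Dense(4 => 8, tanh), Dense(8 => 2))
    ϕ, rebuild = Flux.destructure(mchain)

    # without the assertion the result infers as Any, because the type of
    # the rebuilt model is unknown at compile time
    apply_plain(rebuild, x, ϕ) = rebuild(ϕ)(x)
    # with the assertion the result type is pinned to the input type
    apply_hinted(rebuild, x::T, ϕ) where {T} = rebuild(ϕ)(x)::T

    x = rand(Float32, 4, 3)
    y = @inferred apply_hinted(rebuild, x, ϕ)   # passes
    # @inferred apply_plain(rebuild, x, ϕ)      # would fail: inferred Any
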
From 424655763710055922aa19cb88b04496a7256fa7 Mon Sep 17 00:00:00 2001
From: Thomas Wutzler
Date: Fri, 16 May 2025 16:22:37 +0200
Subject: [PATCH 2/3] implement and test a type-stable rebuild FluxApplicator

tests show no significant speedup

---
 ext/HybridVariationalInferenceFluxExt.jl | 21 ++++++++++++++++++++-
 src/HybridVariationalInference.jl        |  1 +
 src/ModelApplicator.jl                   | 10 ++++++++++
 test/test_Flux.jl                        |  8 ++++++++
 4 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl
index f634270..da8275d 100644
--- a/ext/HybridVariationalInferenceFluxExt.jl
+++ b/ext/HybridVariationalInferenceFluxExt.jl
@@ -10,6 +10,13 @@ struct FluxApplicator{RT} <: AbstractModelApplicator
     rebuild::RT
 end
 
+struct PartricFluxApplicator{RT, MT, YT} <: AbstractModelApplicator
+    rebuild::RT
+end
+
+const FluxApplicatorU{RT} = Union{FluxApplicator{RT},PartricFluxApplicator{RT}} where RT
+
+
 function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
     # TODO: care for rng and float_type
     ϕ, rebuild = Flux.destructure(m)
@@ -23,7 +30,18 @@ function HVI.apply_model(app::FluxApplicator, x::T, ϕ) where T
     res = m(x)::T
     res
 end
-function _apply_model(m,x) # function barrier so that m is inferred
+
+
+function HVI.construct_partric(app::FluxApplicator{RT}, x, ϕ) where RT
+    m = app.rebuild(ϕ)
+    y = m(x)
+    PartricFluxApplicator{RT, typeof(m), typeof(y)}(app.rebuild)
+end
+
+function HVI.apply_model(app::PartricFluxApplicator{RT, MT, YT}, x, ϕ) where {RT, MT, YT}
+    m = app.rebuild(ϕ)::MT
+    res = m(x)::YT
+    res
 end
 
 # struct FluxGPUDataHandler <: AbstractGPUDataHandler end
@@ -70,4 +88,5 @@ end
 
 
+
+
 end # module
diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl
index 97d2bf5..9ea9b57 100644
--- a/src/HybridVariationalInference.jl
+++ b/src/HybridVariationalInference.jl
@@ -28,6 +28,7 @@ include("bijectors_utils.jl")
 export AbstractComponentArrayInterpreter, ComponentArrayInterpreter,
     StaticComponentArrayInterpreter
 export flatten1, get_concrete, get_positions, stack_ca_int, compose_interpreters
+export construct_partric
 include("ComponentArrayInterpreter.jl")
 
 export AbstractModelApplicator, construct_ChainsApplicator
diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl
index 5f83155..5c27c92 100644
--- a/src/ModelApplicator.jl
+++ b/src/ModelApplicator.jl
@@ -34,6 +34,16 @@ function apply_model(app::NullModelApplicator, x, ϕ)
     return x
 end
 
+"""
+Construct a parametric type-stable model applicator, given
+covariates, `x`, and parameters, `ϕ`.
+
+The default retuns the current model applicator.
+"""
+function construct_partric(app::AbstractModelApplicator, x, ϕ)
+    app
+end
+
 """
     construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type)
 
diff --git a/test/test_Flux.jl b/test/test_Flux.jl
index d19ddd3..0260208 100644
--- a/test/test_Flux.jl
+++ b/test/test_Flux.jl
@@ -56,6 +56,14 @@ using Flux
     # @descend_code_warntype g(x, ϕ)
     #@test ϕ isa GPUArraysCore.AbstractGPUArray
     @test size(y) == (n_out, n_site)
+    gp = construct_partric(g, x, ϕ)
+    y2 = @inferred gp(x, ϕ)
+    @test y2 == y
+    () -> begin
+        # @usingany BenchmarkTools
+        #@benchmark g(x,ϕ)
+        #@benchmark gp(x,ϕ) # no difference type-inferred
+    end
 end;
 
 @testset "cpu_ca" begin
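
How the parametric applicator achieves type stability: `construct_partric`
runs the model once on a sample input and records the concrete types of the
rebuilt model (`MT`) and of its output (`YT`) as type parameters of
`PartricFluxApplicator`; every later call then asserts `rebuild(ϕ)::MT` and
`m(x)::YT`, so inference no longer depends on the opaque `Restructure`
closure. A usage sketch along the lines of the new test (it assumes the
package with its Flux extension is loaded, and that applicators are callable
as `app(x, ϕ)`, as the test does; the chain sizes are made up):

    using Flux, Random, Test
    using HybridVariationalInference

    mchain = Chain(Dense(4 => 8, tanh), Dense(8 => 2))
    g, ϕ = construct_ChainsApplicator(Random.default_rng(), mchain, Float32)
    x = rand(Float32, 4, 3)

    gp = construct_partric(g, x, ϕ)   # one trial evaluation records MT and YT
    y2 = @inferred gp(x, ϕ)
    @test y2 == g(x, ϕ)

Two consequences of this construction: a call whose result is not of the
recorded type `YT` fails the assertion, so `gp` is tied to inputs like the
sample `x` it was built with; and, as the benchmark comment in the test notes,
no speedup over the plain `FluxApplicator` is expected here, because the `::T`
assertion of the first patch already pins the result type at the call site.
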
From 34ebd6de5ba25306acc71607a47057e815fca227 Mon Sep 17 00:00:00 2001
From: Thomas Wutzler
Date: Mon, 26 May 2025 09:24:37 +0200
Subject: [PATCH 3/3] typos

---
 src/ModelApplicator.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl
index 5c27c92..85a1140 100644
--- a/src/ModelApplicator.jl
+++ b/src/ModelApplicator.jl
@@ -38,7 +38,7 @@ end
 Construct a parametric type-stable model applicator, given
 covariates, `x`, and parameters, `ϕ`.
 
-The default retuns the current model applicator.
+The default returns the current model applicator.
 """
 function construct_partric(app::AbstractModelApplicator, x, ϕ)
     app
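
For comparison, the function barrier that the short-lived `_apply_model` stub
of the first patch pointed at (and the second patch removed) would have looked
roughly like the sketch below; the names are hypothetical, not package code. A
barrier specializes the body of the inner function on the concrete type of
`m`, so the forward pass itself runs compiled code, but the outer call site
still infers `Any`; the assertion-based applicators fix the call-site type as
well, which is why that route was kept:

    # outer function: m infers as Any because Restructure is opaque
    function apply_model_barrier(rebuild, x, ϕ)
        m = rebuild(ϕ)       # concrete type unknown to the compiler here
        _apply_model(m, x)   # barrier: one dynamic dispatch, body specialized
    end

    # compiled separately for each concrete model type
    _apply_model(m, x) = m(x)
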