diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl
index cda7157..da8275d 100644
--- a/ext/HybridVariationalInferenceFluxExt.jl
+++ b/ext/HybridVariationalInferenceFluxExt.jl
@@ -10,15 +10,37 @@ struct FluxApplicator{RT} <: AbstractModelApplicator
     rebuild::RT
 end
 
+struct PartricFluxApplicator{RT, MT, YT} <: AbstractModelApplicator
+    rebuild::RT
+end
+
+const FluxApplicatorU{RT} = Union{FluxApplicator{RT}, PartricFluxApplicator{RT}}
+
+
 function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
     # TODO: care fore rng and float_type
     ϕ, rebuild = Flux.destructure(m)
     FluxApplicator(rebuild), ϕ
 end
 
-function HVI.apply_model(app::FluxApplicator, x, ϕ)
+function HVI.apply_model(app::FluxApplicator, x::T, ϕ) where T
+    # x carries no size information, so we can hint that the result
+    # has the same type as the input
     m = app.rebuild(ϕ)
-    res = m(x)
+    res = m(x)::T
     res
 end
+
+
+function HVI.construct_partric(app::FluxApplicator{RT}, x, ϕ) where RT
+    m = app.rebuild(ϕ)
+    y = m(x)
+    PartricFluxApplicator{RT, typeof(m), typeof(y)}(app.rebuild)
+end
+
+function HVI.apply_model(app::PartricFluxApplicator{RT, MT, YT}, x, ϕ) where {RT, MT, YT}
+    m = app.rebuild(ϕ)::MT
+    res = m(x)::YT
+    res
+end
@@ -66,4 +88,5 @@ end
+
 end # module
diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl
index 97d2bf5..9ea9b57 100644
--- a/src/HybridVariationalInference.jl
+++ b/src/HybridVariationalInference.jl
@@ -28,6 +28,7 @@ include("bijectors_utils.jl")
 export AbstractComponentArrayInterpreter, ComponentArrayInterpreter,
     StaticComponentArrayInterpreter
 export flatten1, get_concrete, get_positions, stack_ca_int, compose_interpreters
+export construct_partric
 include("ComponentArrayInterpreter.jl")
 
 export AbstractModelApplicator, construct_ChainsApplicator
diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl
index 5f83155..85a1140 100644
--- a/src/ModelApplicator.jl
+++ b/src/ModelApplicator.jl
@@ -34,6 +34,16 @@ function apply_model(app::NullModelApplicator, x, ϕ)
     return x
 end
 
+"""
+Construct a parametric, type-stable model applicator, given
+covariates, `x`, and parameters, `ϕ`.
+
+The default returns the current model applicator unchanged.
+"""
+function construct_partric(app::AbstractModelApplicator, x, ϕ)
+    app
+end
+
 """
     construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type)
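For orientation, a minimal usage sketch of the new API, combining the `construct_partric` default above with the Flux method from the extension. The `Chain` layout, sizes, and data here are illustrative assumptions, not taken from this diff:

```julia
using Flux, Random
using HybridVariationalInference

# build a FluxApplicator and its flattened parameter vector
g, ϕ = construct_ChainsApplicator(Random.default_rng(),
    Chain(Dense(3 => 8, tanh), Dense(8 => 2)), Float32)
x = rand(Float32, 3, 5)        # covariates: n_covar × n_site

# one trial evaluation records typeof(model) and typeof(output)
# as type parameters of a PartricFluxApplicator
gp = construct_partric(g, x, ϕ)
y = gp(x, ϕ)                   # later calls infer the concrete result type
```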
+""" +function construct_partric(app::AbstractModelApplicator, x, ϕ) + app +end + """ construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type) diff --git a/src/elbo.jl b/src/elbo.jl index 66c4ce4..8bd4708 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -296,7 +296,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT; xMP0 = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices) #Main.@infiltrate_main - μ_ζMs0 = g(xMP0, ϕg)::MT # for gpu restructure returns Any, so apply type + μ_ζMs0 = g(xMP0, ϕg) ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_ends, int_unc) if pbm_covar_indices isa SA.SVector{0} # do not need to predict again but just add the residuals to μ_ζP and μ_ζMs @@ -308,7 +308,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT; ζP = μ_ζP .+ rP # second pass: append ζP rather than μ_ζP to covars to xM xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) - μ_ζMst = g(xMP, ϕg)::MT # for gpu restructure returns Any, so apply type + μ_ζMst = g(xMP, ϕg) ζMs = (μ_ζMst .+ rMs)' # already transform to par-last form ζP, ζMs end @@ -356,26 +356,27 @@ function get_pbm_covar_indices(ζP, pbm_covars::NTuple{0}, SA.SA[] end -# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N -# xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name? +# remove? +# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N +# # xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name? +# # μ_ζMs = g(xMP2, ϕg) +# # end +# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0) +# # # if pbm_covars is the empty tuple, just return the original prediction on xM only +# # # rather than calling the ML model +# # μ_ζMs0 +# # end + +# function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0) +# xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) # μ_ζMs = g(xMP2, ϕg) # end -# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0) +# function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0) # # if pbm_covars is the empty tuple, just return the original prediction on xM only # # rather than calling the ML model # μ_ζMs0 # end -function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0) - xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices) - μ_ζMs = g(xMP2, ϕg) -end -function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0) - # if pbm_covars is the empty tuple, just return the original prediction on xM only - # rather than calling the ML model - μ_ζMs0 -end - """ Extract relevant parameters from ζ and return n_MC generated multivariate normal draws together with the vector of standard deviations, `σ`: `(ζP_resids, ζMs_parfirst_resids, σ)` diff --git a/src/gf.jl b/src/gf.jl index 2fefe2d..44ee7eb 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -120,10 +120,10 @@ end composition transM ∘ g: transformation after machine learning parameter prediction Provide a `transMs = StackedArray(transM, n_batch)` """ -function gtrans(g, transMs, xMP::T, ϕg; cdev) where T +function gtrans(g, transMs, xMP, ϕg; cdev) # TODO remove after removing gf # predict the log of the parameters - ζMst = g(xMP, ϕg)::T # problem of Flux model applicator restructure + ζMst = g(xMP, ϕg) ζMs = ζMst' ζMs_cpu = cdev(ζMs) θMs = transMs(ζMs_cpu) diff --git a/test/test_Flux.jl b/test/test_Flux.jl index 
diff --git a/test/test_Flux.jl b/test/test_Flux.jl
index 8cbaa07..0260208 100644
--- a/test/test_Flux.jl
+++ b/test/test_Flux.jl
@@ -51,9 +51,19 @@ using Flux
     n_site = 3
     x = rand(Float32, n_covar, n_site) |> gpu
     ϕ = ϕg |> gpu
-    y = g(x, ϕ)
+    y = @inferred g(x, ϕ)
+    # @usingany Cthulhu
+    # @descend_code_warntype g(x, ϕ)
     #@test ϕ isa GPUArraysCore.AbstractGPUArray
     @test size(y) == (n_out, n_site)
+    gp = construct_partric(g, x, ϕ)
+    y2 = @inferred gp(x, ϕ)
+    @test y2 == y
+    () -> begin
+        # @usingany BenchmarkTools
+        # @benchmark g(x, ϕ)
+        # @benchmark gp(x, ϕ) # no timing difference once type-inferred
+    end
 end;
 
 @testset "cpu_ca" begin
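For applicators without a specialized `construct_partric` method, the default added in src/ModelApplicator.jl is the identity, so callers can apply it unconditionally. A sketch, assuming `NullModelApplicator` is reachable via the module (its export is not shown in this diff) and that applicators are callable as in the test above:

```julia
using HybridVariationalInference
const HVI = HybridVariationalInference

app = HVI.NullModelApplicator()    # apply_model returns x unchanged
x, ϕ = rand(Float32, 3, 2), Float32[]
@assert construct_partric(app, x, ϕ) === app  # default: the applicator itself
@assert app(x, ϕ) == x
```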