Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions ext/HybridVariationalInferenceFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,37 @@ struct FluxApplicator{RT} <: AbstractModelApplicator
rebuild::RT
end

# Type-stable ("parametric") variant of `FluxApplicator`.
# `MT` (type of the rebuilt model) and `YT` (type of the model output) are
# phantom type parameters: they have no corresponding field and are recorded
# once at construction time by `construct_partric`, so that `apply_model`
# can type-assert the otherwise type-unstable `rebuild(ϕ)` call and its result.
struct PartricFluxApplicator{RT, MT, YT} <: AbstractModelApplicator
rebuild::RT
end

# Convenience union to dispatch on both the plain and the parametric
# Flux applicator sharing the same rebuild-closure type `RT`.
const FluxApplicatorU{RT} = Union{FluxApplicator{RT},PartricFluxApplicator{RT}} where RT

"""
    construct_ChainsApplicator(rng, m::Chain, float_type)

Destructure the Flux `Chain` `m` into a flat parameter vector and a
rebuild closure, returning `(applicator, ϕ)` where `applicator` is a
`FluxApplicator` wrapping the rebuild closure.
"""
function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType)
    # TODO: care for rng and float_type (currently not used)
    flat_params, rebuild_fn = Flux.destructure(m)
    return FluxApplicator(rebuild_fn), flat_params
end

function HVI.apply_model(app::FluxApplicator, x, ϕ)
"""
    apply_model(app::FluxApplicator, x, ϕ)

Rebuild the Flux model from the flat parameter vector `ϕ` and apply it
to covariates `x`.
"""
function HVI.apply_model(app::FluxApplicator, x::T, ϕ) where T
    # assume no size information in x -> can hint the type of the result
    # to be the same as the type of the input
    # NOTE(review): asserting `m(x)::T` presumes the output array type equals
    # the input array type (e.g. Matrix{Float32} -> Matrix{Float32}); this
    # fails for view/adjoint inputs — confirm callers pass plain arrays.
    m = app.rebuild(ϕ)
    # Fix: the previous version evaluated `m(x)` twice (a stale duplicate
    # assignment without the type assertion preceded the asserted one).
    res = m(x)::T
    res
end


"""
    construct_partric(app::FluxApplicator, x, ϕ)

Run one probing forward pass with covariates `x` and parameters `ϕ` to
capture the concrete types of the rebuilt model and of its output, and
return a `PartricFluxApplicator` that records them as type parameters.
"""
function HVI.construct_partric(app::FluxApplicator{RT}, x, ϕ) where RT
    model = app.rebuild(ϕ)
    out = model(x)
    return PartricFluxApplicator{RT, typeof(model), typeof(out)}(app.rebuild)
end

"""
    apply_model(app::PartricFluxApplicator, x, ϕ)

Rebuild the Flux model from `ϕ` and apply it to `x`, type-asserting
both the model (`MT`) and the result (`YT`) with the concrete types
recorded at construction time for type stability.
"""
function HVI.apply_model(app::PartricFluxApplicator{RT, MT, YT}, x, ϕ) where {RT, MT, YT}
    model = app.rebuild(ϕ)::MT
    return model(x)::YT
end

Expand Down Expand Up @@ -66,4 +88,5 @@ end




end # module
1 change: 1 addition & 0 deletions src/HybridVariationalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include("bijectors_utils.jl")
export AbstractComponentArrayInterpreter, ComponentArrayInterpreter,
StaticComponentArrayInterpreter
export flatten1, get_concrete, get_positions, stack_ca_int, compose_interpreters
export construct_partric
include("ComponentArrayInterpreter.jl")

export AbstractModelApplicator, construct_ChainsApplicator
Expand Down
10 changes: 10 additions & 0 deletions src/ModelApplicator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ function apply_model(app::NullModelApplicator, x, ϕ)
return x
end

"""
Construct a parametric type-stable model applicator, given
covariates, `x`, and parameters, `ϕ`.

The default returns the current model applicator.
"""
function construct_partric(app::AbstractModelApplicator, x, ϕ)
app
end


"""
construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type)
Expand Down
31 changes: 16 additions & 15 deletions src/elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
xMP0 = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices)
#Main.@infiltrate_main

μ_ζMs0 = g(xMP0, ϕg)::MT # for gpu restructure returns Any, so apply type
μ_ζMs0 = g(xMP0, ϕg)
ζP_resids, ζMs_parfirst_resids, σ = sample_ζresid_norm(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_ends, int_unc)
if pbm_covar_indices isa SA.SVector{0}
# do not need to predict again but just add the residuals to μ_ζP and μ_ζMs
Expand All @@ -308,7 +308,7 @@ function generate_ζ(rng, g, ϕ::AbstractVector{FT}, xM::MT;
ζP = μ_ζP .+ rP
# second pass: append ζP rather than μ_ζP to covars to xM
xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
μ_ζMst = g(xMP, ϕg)::MT # for gpu restructure returns Any, so apply type
μ_ζMst = g(xMP, ϕg)
ζMs = (μ_ζMst .+ rMs)' # already transform to par-last form
ζP, ζMs
end
Expand Down Expand Up @@ -356,26 +356,27 @@ function get_pbm_covar_indices(ζP, pbm_covars::NTuple{0},
SA.SA[]
end

# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N
# xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name?
# remove?
# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{N,Symbol}, g, ϕg, μ_ζMs0) where N
# # xMP2 = _append_PBM_covars(xM, ζP, pbm_covars) # need different variable name?
# # μ_ζMs = g(xMP2, ϕg)
# # end
# # function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0)
# # # if pbm_covars is the empty tuple, just return the original prediction on xM only
# # # rather than calling the ML model
# # μ_ζMs0
# # end

# function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0)
# xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
# μ_ζMs = g(xMP2, ϕg)
# end
# function _predict_μ_ζMs(xM, ζP, pbm_covars::NTuple{0}, g, ϕg, μ_ζMs0)
# function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0)
# # if pbm_covars is the empty tuple, just return the original prediction on xM only
# # rather than calling the ML model
# μ_ζMs0
# end

# Append the (sampled) process-model parameters `ζP` as additional covariate
# rows to `xM` (at positions `pbm_covar_indices`) and re-run the ML model `g`
# with parameters `ϕg` to predict the site parameter means.
# NOTE(review): `μ_ζMs0` is unused in this method — it exists so both methods
# share one signature; the zero-covar method below returns it directly.
function _predict_μ_ζMs(xM, ζP, pbm_covar_indices::AbstractVector, g, ϕg, μ_ζMs0)
xMP2 = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
μ_ζMs = g(xMP2, ϕg)
end
# Zero-length static index vector: no PBM covariates to append.
function _predict_μ_ζMs(xM, ζP, pbm_covars_indices::SA.StaticVector{0}, g, ϕg, μ_ζMs0)
# if pbm_covars is the empty tuple, just return the original prediction on xM only
# rather than calling the ML model
μ_ζMs0
end

"""
Extract relevant parameters from ζ and return n_MC generated multivariate normal draws
together with the vector of standard deviations, `σ`: `(ζP_resids, ζMs_parfirst_resids, σ)`
Expand Down
4 changes: 2 additions & 2 deletions src/gf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ end
composition transM ∘ g: transformation after machine learning parameter prediction
Provide a `transMs = StackedArray(transM, n_batch)`
"""
function gtrans(g, transMs, xMP::T, ϕg; cdev) where T
function gtrans(g, transMs, xMP, ϕg; cdev)
# TODO remove after removing gf
# predict the log of the parameters
ζMst = g(xMP, ϕg)::T # problem of Flux model applicator restructure
ζMst = g(xMP, ϕg)
ζMs = ζMst'
ζMs_cpu = cdev(ζMs)
θMs = transMs(ζMs_cpu)
Expand Down
12 changes: 11 additions & 1 deletion test/test_Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,19 @@ using Flux
n_site = 3
x = rand(Float32, n_covar, n_site) |> gpu
ϕ = ϕg |> gpu
y = g(x, ϕ)
y = @inferred g(x, ϕ)
# @usingany Cthulhu
# @descend_code_warntype g(x, ϕ)
#@test ϕ isa GPUArraysCore.AbstractGPUArray
@test size(y) == (n_out, n_site)
gp = construct_partric(g, x, ϕ)
y2 = @inferred gp(x, ϕ)
@test y2 == y
() -> begin
# @usingany BenchmarkTools
#@benchmark g(x,ϕ)
#@benchmark gp(x,ϕ) # no difference type-inferred
end
end;

@testset "cpu_ca" begin
Expand Down
Loading