diff --git a/Project.toml b/Project.toml
index 9479100..74659e6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,6 +12,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -38,6 +39,7 @@ Flux = "v0.15.2, 0.16"
 GPUArraysCore = "0.1, 0.2"
 LinearAlgebra = "1.10.0"
 Lux = "1.4.2"
+MLUtils = "0.4.5"
 Random = "1.10.0"
 SimpleChains = "0.4"
 StatsBase = "0.34.4"
diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl
index 9061e27..5cec624 100644
--- a/dev/doubleMM.jl
+++ b/dev/doubleMM.jl
@@ -12,8 +12,8 @@ using MLUtils
 import Zygote
 using CUDA
-using TransformVariables
 using OptimizationOptimisers
+using Bijectors
 using UnicodePlots
 
 const case = DoubleMM.DoubleMMCase()
@@ -24,13 +24,13 @@ rng = StableRNG(111)
 
 par_templates = get_hybridcase_par_templates(case; scenario)
 
-(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
+(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
 
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
+(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o
 ) = gen_hybridcase_synthetic(case, rng; scenario);
 
 #----- fit g to θMs_true
-g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
+g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario);
 
 function loss_g(ϕg, x, g)
     ζMs = g(x, ϕg) # predict the log of the parameters
@@ -51,7 +51,7 @@ loss_g(ϕg_opt1, xM, g)
 scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
 @test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
 
-f = gen_hybridcase_PBmodel(case; scenario)
+f = get_hybridcase_PBmodel(case; scenario)
 
 #----------- fit g and θP to y_o
 () -> begin
@@ -84,6 +84,9 @@ end
 #---------- HVI
 logσ2y = 2 .* log.(σ_o)
 n_MC = 3
+transP = elementwise(exp)
+transM = Stacked(elementwise(identity), elementwise(exp))
+
 (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
-    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊);
+    θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
 ϕ_true = ϕ
@@ -188,7 +191,7 @@ end
 ϕ = ϕ_ini |> Flux.gpu;
 xM_gpu = xM |> Flux.gpu;
 
-g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario);
+g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario);
 
-# otpimize using LUX
+# optimize using Lux
 () -> begin
@@ -224,7 +227,8 @@ gr = Zygote.gradient(fcost, CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]),
     CA.getdata(y_o[:, 1:n_batch]));
 gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...)
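+# The DataLoader constructed below batches each component of the data tuple
+# along its last dimension. A minimal sketch of consuming it (illustration
+# only; names as defined above, each full batch holds n_batch sites):
+# for (xM_batch, xP_batch, y_o_batch) in train_loader
+#     @assert size(xM_batch, 2) == size(y_o_batch, 2) == n_batch
+# end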
-train_loader = MLUtils.DataLoader((xM_gpu, y_o), batchsize = n_batch)
+#train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch)
+train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux))
 
 optf = Optimization.OptimizationFunction(
     (ϕ, data) -> begin
diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl
index 61b7095..1d639bb 100644
--- a/ext/HybridVariationalInferenceFluxExt.jl
+++ b/ext/HybridVariationalInferenceFluxExt.jl
@@ -2,14 +2,15 @@ module HybridVariationalInferenceFluxExt
 
 using HybridVariationalInference, Flux
 using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays: ComponentArrays as CA
 
 struct FluxApplicator{RT} <: AbstractModelApplicator
     rebuild::RT
 end
 
 function HVI.construct_FluxApplicator(m::Chain)
-    _, rebuild = destructure(m)
-    FluxApplicator(rebuild)
+    ϕ, rebuild = destructure(m)
+    FluxApplicator(rebuild), ϕ
 end
 
 function HVI.apply_model(app::FluxApplicator, x, ϕ)
@@ -25,7 +26,14 @@ function __init__()
     HVI.set_default_GPUHandler(FluxGPUDataHandler())
 end
 
-function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
+function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain,
+        args...; kwargs...)
+    # constructor with Flux.Chain
+    g, ϕg = construct_FluxApplicator(g_chain)
+    HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
+end
+
+function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux};
         scenario::NTuple = ())
     (; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
     FloatType = get_hybridcase_FloatType(case; scenario)
@@ -39,8 +47,9 @@ function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
         # dense layer without bias that maps to n outputs and `identity` activation
         Flux.Dense(n_covar * 4 => n_out, identity, bias = false)
     )
-    ϕ, _ = destructure(g_chain)
-    construct_FluxApplicator(g_chain), ϕ
+    construct_FluxApplicator(g_chain)
 end
 
 end # module
diff --git a/ext/HybridVariationalInferenceLuxExt.jl b/ext/HybridVariationalInferenceLuxExt.jl
index bb34158..bfcf6cb 100644
--- a/ext/HybridVariationalInferenceLuxExt.jl
+++ b/ext/HybridVariationalInferenceLuxExt.jl
@@ -10,14 +10,14 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator
     int_ϕ::IT
 end
 
-function HVI.construct_LuxApplicator(m::Chain; device = gpu_device())
+function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device())
     ps, st = Lux.setup(Random.default_rng(), m)
-    ps_ca = CA.ComponentArray(ps)
+    ps_ca = float_type.(CA.ComponentArray(ps))
     st = st |> device
     stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
     #stateful_layer(x_o_gpu[:, 1:n_site_batch], ps_ca)
     int_ϕ = get_concrete(ComponentArrayInterpreter(ps_ca))
-    LuxApplicator(stateful_layer, int_ϕ)
+    LuxApplicator(stateful_layer, int_ϕ), ps_ca
 end
 
 function HVI.apply_model(app::LuxApplicator, x, ϕ)
@@ -25,4 +25,11 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ)
     app.stateful_layer(x, ϕc)
 end
 
+function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain,
+        args...; device = gpu_device(), kwargs...)
+    # constructor with Lux.Chain
+    g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device)
+    HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
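+    # (construct_LuxApplicator above converts the Lux parameters to eltype(θM),
+    # so Float64 parameter templates yield a Float64 parameter vector ϕg)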
+end
+
 end # module
diff --git a/ext/HybridVariationalInferenceSimpleChainsExt.jl b/ext/HybridVariationalInferenceSimpleChainsExt.jl
index 520be53..f95caa9 100644
--- a/ext/HybridVariationalInferenceSimpleChainsExt.jl
+++ b/ext/HybridVariationalInferenceSimpleChainsExt.jl
@@ -3,16 +3,29 @@ module HybridVariationalInferenceSimpleChainsExt
 using HybridVariationalInference, SimpleChains
 using HybridVariationalInference: HybridVariationalInference as HVI
 using StatsFuns: logistic
+using ComponentArrays: ComponentArrays as CA
+
 struct SimpleChainsApplicator{MT} <: AbstractModelApplicator
     m::MT
 end
 
-HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
+function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32)
+    ϕ = SimpleChains.init_params(m, FloatType)
+    SimpleChainsApplicator(m), ϕ
+end
 
 HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
 
-function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
+function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain,
+        args...; kwargs...)
+    # constructor with SimpleChain
+    g, ϕg = construct_SimpleChainsApplicator(g_chain)
+    HybridProblem(θP, θM, g, ϕg, args...; kwargs...)
+end
+
+function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
     scenario::NTuple=())
     (;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
     FloatType = get_hybridcase_FloatType(case; scenario)
@@ -39,8 +52,7 @@ function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{
         TurboDense{false}(identity, n_out)
     )
     end
-    ϕ = SimpleChains.init_params(g_chain, FloatType);
-    SimpleChainsApplicator(g_chain), ϕ
+    construct_SimpleChainsApplicator(g_chain, FloatType)
 end
 
 end # module
diff --git a/src/DoubleMM/DoubleMM.jl b/src/DoubleMM/DoubleMM.jl
index 33b535d..1487a18 100644
--- a/src/DoubleMM/DoubleMM.jl
+++ b/src/DoubleMM/DoubleMM.jl
@@ -1,14 +1,16 @@
 module DoubleMM
 
 using HybridVariationalInference
+using HybridVariationalInference: HybridVariationalInference as HVI
 using ComponentArrays: ComponentArrays as CA
 using Random
 using Combinatorics
 using StatsFuns: logistic
+using Bijectors
 
+export f_doubleMM, xP_S1, xP_S2
 include("f_doubleMM.jl")
 
-export f_doubleMM, S1, S2
 
 end
\ No newline at end of file
diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl
index 4c92fdd..c69680d 100644
--- a/src/DoubleMM/f_doubleMM.jl
+++ b/src/DoubleMM/f_doubleMM.jl
@@ -1,67 +1,79 @@
 struct DoubleMMCase <: AbstractHybridCase end
 
-const S1 = [1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
-const S2 = [1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
-θP = CA.ComponentVector(r0 = 0.3, K2 = 2.0)
-θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
+θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
+θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2)
+
+transP = elementwise(exp)
+transM = Stacked(elementwise(identity), elementwise(exp))
+
 const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
 
-function f_doubleMM(θ::AbstractVector)
+function f_doubleMM(θ::AbstractVector, x)
-    # extract parameters not depending on order, i.e whether they are in θP or θM
+    # extract parameters not depending on order, i.e., whether they are in θP or θM
     θc = int_θdoubleMM(θ)
     r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)]
-    y = r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+    y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
     return (y)
 end
 
-function HybridVariationalInference.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
+function HVI.get_hybridcase_par_templates(::DoubleMMCase;
scenario::NTuple = ()) (; θP, θM) end -function HybridVariationalInference.get_hybridcase_sizes(::DoubleMMCase; scenario = ()) +function HVI.get_hybridcase_transforms(::AbstractHybridCase; scenario::NTuple = ()) + (; transP, transM) +end + +function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ()) n_covar_pc = 2 n_covar = n_covar_pc + 3 # linear dependent - n_site = 10^n_covar_pc + #n_site = 10^n_covar_pc n_batch = 10 n_θM = length(θM) n_θP = length(θP) - (; n_covar, n_site, n_batch, n_θM, n_θP) + #(; n_covar, n_site, n_batch, n_θM, n_θP) + (; n_covar, n_batch, n_θM, n_θP) end -function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ()) - fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers +function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ()) + #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x) - pred_sites = applyf(fsite, θMs, θP, x) + pred_sites = applyf(f_doubleMM, θMs, θP, x) pred_global = eltype(pred_sites)[] return pred_global, pred_sites end end -function HybridVariationalInference.get_hybridcase_FloatType(::DoubleMMCase; scenario) - return Float32 -end +# function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario) +# return Float32 +# end -function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG; +const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1] +const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0] + +function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG; scenario = ()) n_covar_pc = 2 - (; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) + n_site = 200 + (; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) FloatType = get_hybridcase_FloatType(case; scenario) xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM; rhodec = 8, is_using_dropout = false) int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,)) # normalize to be distributed around the prescribed true values - θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, 0.1)) - f = gen_hybridcase_PBmodel(case; scenario) - xP = fill((), n_site) - y_global_true, y_true = f(θP, θMs_true, zip()) - σ_o = 0.01 + θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1))) + f = get_hybridcase_PBmodel(case; scenario) + xP = fill((;S1=xP_S1, S2=xP_S2), n_site) + y_global_true, y_true = f(θP, θMs_true, xP) + σ_o = FloatType(0.01) #σ_o = 0.002 - y_global_o = y_global_true .+ randn(rng, size(y_global_true)) .* σ_o - y_o = y_true .+ randn(rng, size(y_true)) .* σ_o + y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o + y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o (; xM, + n_site, θP_true = θP, θMs_true, xP, @@ -72,3 +84,6 @@ function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase, σ_o = fill(σ_o, size(y_true,1)), ) end + + + diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl new file mode 100644 index 0000000..65c48c6 --- /dev/null +++ b/src/HybridProblem.jl @@ -0,0 +1,55 @@ +struct HybridProblem <: AbstractHybridCase + θP + θM + transP + transM + n_covar + n_batch + f + g + ϕg + train_loader + # inner constructor to constrain the types + function HybridProblem( + θP::CA.ComponentVector, θM::CA.ComponentVector, + g::AbstractModelApplicator, ϕg, + f::Function, + transM::Union{Function, Bijectors.Transform}, + transP::Union{Function, Bijectors.Transform}, + 
n_covar::Integer, n_batch::Integer,
+            train_loader::DataLoader)
+        new(θP, θM, transP, transM, n_covar, n_batch, f, g, ϕg, train_loader)
+    end
+end
+
+function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ())
+    (; θP = prob.θP, θM = prob.θM)
+end
+
+function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ())
+    n_θM = length(prob.θM)
+    n_θP = length(prob.θP)
+    (; n_covar = prob.n_covar, n_batch = prob.n_batch, n_θM, n_θP)
+end
+
+function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ())
+    prob.f
+end
+
+function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::NTuple = ())
+    prob.g, prob.ϕg
+end
+
+function get_hybridcase_train_dataloader(
+        prob::HybridProblem, rng::AbstractRNG = Random.default_rng();
+        scenario = ())
+    return prob.train_loader
+end
+
+# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ())
+#     eltype(prob.θM)
+# end
diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl
index 370bdd0..a156030 100644
--- a/src/HybridVariationalInference.jl
+++ b/src/HybridVariationalInference.jl
@@ -11,6 +11,7 @@ using ChainRulesCore
 using Bijectors
 using Zygote # Zygote.@ignore CUDA.randn
 using BlockDiagonals
+using MLUtils # dataloader
 
 export ComponentArrayInterpreter, flatten1, get_concrete
 include("ComponentArrayInterpreter.jl")
@@ -22,10 +23,14 @@ include("ModelApplicator.jl")
 export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler
 include("GPUDataHandler.jl")
 
-export AbstractHybridCase, gen_hybridcase_MLapplicator, gen_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic,
-    get_hybridcase_par_templates, gen_cov_pred
+export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic,
+    get_hybridcase_par_templates, get_hybridcase_transforms, get_hybridcase_train_dataloader,
+    gen_cov_pred
 include("hybrid_case.jl")
 
+export HybridProblem
+include("HybridProblem.jl")
+
 export applyf, gf, get_loss_gf
 include("gf.jl")
diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl
index b32f686..1ada30e 100644
--- a/src/ModelApplicator.jl
+++ b/src/ModelApplicator.jl
@@ -1,3 +1,21 @@
+"""
+    AbstractModelApplicator(x, ϕ)
+
+Abstraction of applying a machine learning model to a covariate matrix, `x`,
+using parameters, `ϕ`. It returns a matrix of predictions with the same
+number of columns (sites) as `x`.
+
+Constructors for specifics are defined in extension packages.
+Each constructor takes a special type of machine learning model and returns
+a tuple with two components:
+- the applicator
+- a sample parameter vector (its type depends on the ML framework used)
+
+Implemented are
+- `construct_SimpleChainsApplicator`
+- `construct_FluxApplicator`
+- `construct_LuxApplicator`
+"""
 abstract type AbstractModelApplicator end
 
 function apply_model end
diff --git a/src/elbo.jl b/src/elbo.jl
index 13fbac0..e0cb2f9 100644
--- a/src/elbo.jl
+++ b/src/elbo.jl
@@ -12,22 +12,23 @@ expected value of the likelihood of observations.
 including parameter of f (ϕ_P), of g (ϕ_Ms), and of VI (ϕ_unc),
   interpreted by interpreters.μP_ϕg_unc and interpreters.PMs
 - y_ob: matrix of observations (n_obs x n_site_batch)
-- x: matrix of covariates (n_cov x n_site_batch)
+- xM: matrix of covariates (n_cov x n_site_batch)
+- xP: model drivers, an iterable with one element per site (length n_site_batch)
 - transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params
 - n_MC: number of MonteCarlo samples from the distribution of parameters
   to simulate using the mechanistic model f.
 - logσ2y: observation uncertainty (log of the variance)
"""
-function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractMatrix,
-        transPMs, interpreters::NamedTuple;
+function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, xM::AbstractMatrix,
+        xP, transPMs, interpreters::NamedTuple;
         n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler(),
         entropyN = 0.0, )
-    ζs, σ = generate_ζ(rng, g, f, ϕ, x, interpreters; n_MC)
+    ζs, σ = generate_ζ(rng, g, f, ϕ, xM, interpreters; n_MC)
     ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
     #ζi = first(eachcol(ζs_cpu))
     nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
-        y_pred_i, logjac = predict_y(ζi, f, transPMs, interpreters.PMs)
+        y_pred_i, logjac = predict_y(ζi, xP, f, transPMs, interpreters.PMs)
         nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
         nLy1 - logjac
     end) / n_MC
@@ -45,7 +46,7 @@ end
 
 Prediction function for hybrid model. Returns an Array `(n_obs, n_site, n_sample_pred)`.
 """
-function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters;
+function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP, interpreters;
         get_transPMs, get_ca_int_PMs, n_sample_pred=200,
         gpu_data_handler=get_default_GPUHandler())
     n_site = size(xM, 2)
@@ -56,7 +57,7 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret
         interpreters_gen; n_MC = n_sample_pred)
     ζs_cpu = gpu_data_handler(ζs) #
     y_pred = stack(map(ζ -> first(predict_y(
-        ζ, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
+        ζ, xP, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
     y_pred
 end
@@ -68,19 +69,19 @@
 
 Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0`
 to the means extracted from parameters and predicted by the machine learning model.
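+
+Returns a tuple `(ζ, σ)`: `ζ` holds one column per Monte Carlo draw, each the
+concatenation `vcat(ζP, vec(ζMs))`, and `σ` are the scale parameters returned
+by `sample_ζ_norm0`.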
""" -function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix, +function generate_ζ(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpreters::NamedTuple; n_MC=3) # see documentation of neg_elbo_transnorm_gf ϕc = interpreters.μP_ϕg_unc(CA.getdata(ϕ)) μ_ζP = ϕc.μP ϕg = ϕc.ϕg - μ_ζMs0 = g(x, ϕg) # TODO provide μ_ζP to g + μ_ζMs0 = g(xM, ϕg) # TODO provide μ_ζP to g ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC) #ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC) ζ = stack(map(eachcol(ζ_resid)) do r rc = interpreters.PMs(r) ζP = μ_ζP .+ rc.P - μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g + μ_ζMs = μ_ζMs0 # g(xM, ϕc.ϕ) # TODO provide ζP to g ζMs = μ_ζMs .+ rc.Ms vcat(ζP, vec(ζMs)) end) @@ -168,13 +169,13 @@ Steps: - transform the parameters to original constrained space - Applies the mechanistic model for each site """ -function predict_y(ζi, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter) +function predict_y(ζi, xP, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter) # θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating # θc = CA.ComponentVector(θtup) θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating θc = int_PMs(θ) # TODO provide xP - xP = fill((), size(θc.Ms,2)) + # xP = fill((), size(θc.Ms,2)) y_pred_global, y_pred = f(θc.P, θc.Ms, xP) # TODO parallelize on CPU # TODO take care of y_pred_global y_pred, logjac diff --git a/src/gf.jl b/src/gf.jl index 84a912b..c86098e 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -1,5 +1,5 @@ function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, x) - # predict several sites with same physical parameters + # predict several sites with same global parameters θP yv = map(eachcol(θMs), x) do θM, x_site f(vcat(θP, θM), x_site) end diff --git a/src/hybrid_case.jl b/src/hybrid_case.jl index 7c4ee3d..92b3ed1 100644 --- a/src/hybrid_case.jl +++ b/src/hybrid_case.jl @@ -3,13 +3,15 @@ Type to dispatch constructing data and network structures for different cases of hybrid problem setups For a specific case, provide functions that specify details -- get_hybridcase_par_templates -- get_hybridcase_sizes -- gen_hybridcase_MLapplicator -- gen_hybridcase_PBmodel +- `get_hybridcase_par_templates` +- `get_hybridcase_transforms` +- `get_hybridcase_sizes` +- `get_hybridcase_MLapplicator` +- `get_hybridcase_PBmodel` +- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`) optionally -- gen_hybridcase_synthetic -- get_hybridcase_FloatType (if it should differ from Float32) +- `gen_hybridcase_synthetic` +- `get_hybridcase_FloatType` (defaults to eltype(θM)) """ abstract type AbstractHybridCase end; @@ -20,6 +22,16 @@ Provide tuple of templates of ComponentVectors `θP` and `θM`. 
""" function get_hybridcase_par_templates end + +""" + get_hybridcase_transforms(::AbstractHybridCase; scenario) + +Return a NamedTupe of +- `transP`: Bijectors.Transform for the global PBM parameters, θP +- `transM`: Bijectors.Transform for the single-site PBM parameters, θM +""" +function get_hybridcase_transforms end + """ get_hybridcase_par_templates(::AbstractHybridCase; scenario) @@ -32,7 +44,7 @@ Provide a NamedTuple of number of function get_hybridcase_sizes end """ - gen_hybridcase_MLapplicator(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario=()) + get_hybridcase_MLapplicator(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario=()) Construct the machine learning model fro given problem case and ML-Framework and scenario. @@ -44,10 +56,10 @@ returns a Tuple of - AbstractModelApplicator - initial parameter vector """ -function gen_hybridcase_MLapplicator end +function get_hybridcase_MLapplicator end """ - gen_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=()) + get_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=()) Construct the process-based model function `f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)` @@ -60,7 +72,7 @@ returns a tuple of predictions with components - first, those that are constant across sites - second, those that vary across sites, with a column for each site """ -function gen_hybridcase_PBmodel end +function get_hybridcase_PBmodel end """ gen_hybridcase_synthetic(::AbstractHybridCase, rng; scenario) @@ -81,6 +93,26 @@ function gen_hybridcase_synthetic end Determine the FloatType for given Case and scenario, defaults to Float32 """ -function get_hybridcase_FloatType(::AbstractHybridCase; scenario) - return Float32 +function get_hybridcase_FloatType(case::AbstractHybridCase; scenario) + return eltype(get_hybridcase_par_templates(case; scenario).θM) end + +""" + get_hybridcase_train_dataloader(::AbstractHybridCase, rng; scenario) + +Return a DataLoader that provides a tuple of +- `xM`: matrix of covariates, with one column per site +- `xP`: Iterator of process-model drivers, with one element per site +- `y_o`: matrix of observations with added noise, with one column per site +""" +function get_hybridcase_train_dataloader(case::AbstractHybridCase, rng::AbstractRNG; + scenario = ()) + (; xM, xP, y_o) = gen_hybridcase_synthetic(case, rng; scenario) + (; n_batch) = get_hybridcase_sizes(case; scenario) + xM_gpu = :use_flux ∈ scenario ? CuArray(xM) : xM + train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch) + return(train_loader) +end + + + diff --git a/src/init_hybrid_params.jl b/src/init_hybrid_params.jl index d010916..7480399 100644 --- a/src/init_hybrid_params.jl +++ b/src/init_hybrid_params.jl @@ -12,7 +12,7 @@ Returns a NamedTuple of # Arguments - `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters -- `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator` +- `ϕg`: vector of parameters to optimize, as returned by `get_hybridcase_MLapplicator` - `n_batch`: the number of sites to predicted in each mini-batch - `transP`, `transM`: the Bijector.Transformations for the global and site-dependent parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`. @@ -27,12 +27,13 @@ function init_hybrid_params(θP, θM, ϕg, n_batch; # check translating parameters - can match length? 
_ = Bijectors.inverse(transP)(θP) _ = Bijectors.inverse(transM)(θM) + FT = eltype(θM) # zero correlation matrices - ρsP = zeros(sum(1:(n_θP - 1))) - ρsM = zeros(sum(1:(n_θM - 1))) + ρsP = zeros(FT, sum(1:(n_θP - 1))) + ρsM = zeros(FT, sum(1:(n_θM - 1))) ϕunc0 = CA.ComponentVector(; - logσ2_logP = fill(-10.0, n_θP), - coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), + logσ2_logP = fill(FT(-10.0), n_θP), + coef_logσ2_logMs = reduce(hcat, (FT[-10.0, 0.0] for _ in 1:n_θM)), ρsP, ρsM) ϕ = CA.ComponentVector(; diff --git a/test/runtests.jl b/test/runtests.jl index 50635e6..78ec965 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,8 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml @time @safetestset "test_logden_normal" include("test_logden_normal.jl") #@safetestset "test" include("test/test_doubleMM.jl") @time @safetestset "test_doubleMM" include("test_doubleMM.jl") + #@safetestset "test" include("test/test_HybridProblem.jl") + @time @safetestset "test_HybridProblem" include("test_HybridProblem.jl") #@safetestset "test" include("test/test_cholesky_structure.jl") @time @safetestset "test_cholesky_structure" include("test_cholesky_structure.jl") #@safetestset "test" include("test/test_sample_zeta.jl") @@ -31,6 +33,7 @@ end if GROUP == "All" || GROUP == "Aqua" #@safetestset "test" include("test/test_aqua.jl") if VERSION >= VersionNumber("1.11.2") + #@safetestset "test" include("test/test_aqua.jl") @time @safetestset "test_aqua" include("test_aqua.jl") end end diff --git a/test/test_Flux.jl b/test/test_Flux.jl index ad49eb8..6aa62c3 100644 --- a/test/test_Flux.jl +++ b/test/test_Flux.jl @@ -35,16 +35,19 @@ end; Dense(n_covar * 4 => n_covar * 4, tanh), Dense(n_covar * 4 => n_out, identity, bias=false), ) - g = construct_FluxApplicator(g_chain) + g, ϕg = construct_FluxApplicator(g_chain |> f64) + @test eltype(ϕg) == Float64 + g, ϕg = construct_FluxApplicator(g_chain) + @test eltype(ϕg) == Float32 n_site = 3 x = rand(Float32, n_covar, n_site) - ϕ, _rebuild = destructure(g_chain) - y = g(x, ϕ) + #ϕ, _rebuild = destructure(g_chain) + y = g(x, ϕg) @test size(y) == (n_out, n_site) # n_site = 3 x = rand(Float32, n_covar, n_site) |> gpu - ϕ = ϕ |> gpu + ϕ = ϕg |> gpu y = g(x, ϕ) #@test ϕ isa GPUArraysCore.AbstractGPUArray @test size(y) == (n_out, n_site) diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl new file mode 100644 index 0000000..c7757c1 --- /dev/null +++ b/test/test_HybridProblem.jl @@ -0,0 +1,95 @@ +using Test +using HybridVariationalInference +using StableRNGs +using Random +using Statistics +using ComponentArrays: ComponentArrays as CA +using Bijectors + +using SimpleChains +using MLUtils +import Zygote + +using OptimizationOptimisers + +const MLengine = Val(nameof(SimpleChains)) + +construct_problem = () -> begin + θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0) + θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2) + transP = elementwise(exp) + transM = Stacked(elementwise(identity), elementwise(exp)) + n_covar = 5 + n_batch = 10 + int_θdoubleMM = get_concrete(ComponentArrayInterpreter( + flatten1(CA.ComponentVector(; θP, θM)))) + function f_doubleMM(θ::AbstractVector, x) + # extract parameters not depending on order, i.e whether they are in θP or θM + θc = int_θdoubleMM(θ) + r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] + y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) + return (y) + end + function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x) + pred_sites = applyf(f_doubleMM, θMs, θP, x) + 
pred_global = eltype(pred_sites)[] + return pred_global, pred_sites + end + n_out = length(θM) + g_chain = SimpleChain( + static(n_covar), # input dimension (optional) + # dense layer with bias that maps to 8 outputs and applies `tanh` activation + TurboDense{true}(tanh, n_covar * 4), + TurboDense{true}(tanh, n_covar * 4), + # dense layer without bias that maps to n outputs and `identity` activation + TurboDense{false}(identity, n_out) + ) + # g, ϕg = construct_SimpleChainsApplicator(g_chain) + # + rng = StableRNG(111) + (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o +) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng;) + train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch) + # HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, + # g, ϕg, train_loader) + HybridProblem(θP, θM, g_chain, f_doubleMM_with_global, + transM, transP, n_covar, n_batch, train_loader) +end +prob = construct_problem(); +scenario = (:default,) + +#(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(prob; scenario) + +@testset "loss_gf" begin + #----------- fit g and θP to y_o + g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario) + train_loader = get_hybridcase_train_dataloader(prob; scenario) + (xM, xP, y_o) = first(train_loader) + f = get_hybridcase_PBmodel(prob; scenario) + par_templates = get_hybridcase_par_templates(prob; scenario) + + int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( + ϕg = 1:length(ϕg0), θP = par_templates.θP)) + p = p0 = vcat(ϕg0, par_templates.θP .* 0.8) # slightly disturb θP_true + + # Pass the site-data for the batches as separate vectors wrapped in a tuple + + y_global_o = Float64[] + loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP) + l1 = loss_gf(p0, first(train_loader)...) + gr = Zygote.gradient(p -> loss_gf(p, train_loader.data...)[1], p0) + @test gr[1] isa Vector + + () -> begin + optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1], + Optimization.AutoZygote()) + optprob = OptimizationProblem(optf, p0, train_loader) + + res = Optimization.solve( + # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000); + optprob, Adam(0.02), maxiters = 1000) + + l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...) 
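+        # loss_gf returns (loss, y_pred_global, y_pred, θMs_pred); the recovered
+        # global parameters are read back from the optimized vector via int_ϕθP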
+ @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11) + end +end diff --git a/test/test_Lux.jl b/test/test_Lux.jl index baa90f7..d80da03 100644 --- a/test/test_Lux.jl +++ b/test/test_Lux.jl @@ -1,8 +1,8 @@ using HybridVariationalInference using Test +using CUDA, GPUArraysCore using Lux using StatsFuns: logistic -using CUDA, GPUArraysCore @testset "LuxModelApplicator" begin @@ -13,18 +13,20 @@ using CUDA, GPUArraysCore Dense(n_covar * 4 => n_covar * 4, tanh), Dense(n_covar * 4 => n_out, logistic, use_bias=false), ); - g = construct_LuxApplicator(g_chain; device = cpu_device()); + g, ϕ = construct_LuxApplicator(g_chain, Float64; device = cpu_device()); + @test eltype(ϕ) == Float64 + g, ϕ = construct_LuxApplicator(g_chain; device = cpu_device()); + @test eltype(ϕ) == Float32 n_site = 3 x = rand(Float32, n_covar, n_site) - ϕ = randn(Float32, Lux.parameterlength(g_chain)) + #ϕ = randn(Float32, Lux.parameterlength(g_chain)) y = g(x, ϕ) @test size(y) == (n_out, n_site) # - g = construct_LuxApplicator(g_chain; device = gpu_device()); - n_site = 3 x = rand(Float32, n_covar, n_site) |> gpu_device() - ϕ = randn(Float32, Lux.parameterlength(g_chain)) |> gpu_device() - y = g(x, ϕ) + ϕ_gpu = ϕ |> gpu_device() + #ϕ = randn(Float32, Lux.parameterlength(g_chain)) |> gpu_device() + y = g(x, ϕ_gpu) #@test ϕ isa GPUArraysCore.AbstractGPUArray @test size(y) == (n_out, n_site) end; diff --git a/test/test_SimpleChains.jl b/test/test_SimpleChains.jl index 6036f1e..29adb37 100644 --- a/test/test_SimpleChains.jl +++ b/test/test_SimpleChains.jl @@ -12,10 +12,10 @@ using StatsFuns: logistic TurboDense{true}(tanh, n_covar * 4), TurboDense{false}(logistic, n_out) ) - g = construct_SimpleChainsApplicator(g_chain) + g, ϕg = construct_SimpleChainsApplicator(g_chain) n_site = 3 x = rand(n_covar, n_site) - ϕ = SimpleChains.init_params(g_chain); - y = g(x, ϕ) + #ϕg = SimpleChains.init_params(g_chain); + y = g(x, ϕg) @test size(y) == (n_out, n_site) end; diff --git a/test/test_cholesky_structure.jl b/test/test_cholesky_structure.jl index 58a8624..b02e07e 100644 --- a/test/test_cholesky_structure.jl +++ b/test/test_cholesky_structure.jl @@ -247,8 +247,8 @@ end #@test Upred ≈ CU SUpred = Upred * Dσ #hcat(SUpred, SU) - @test SUpred≈SU atol=2e-1 + @test SUpred≈SU atol=6e-1 S_pred = Dσ' * Upred' * Upred * Dσ - @test S_pred≈S atol=2e-1 + @test S_pred≈S atol=6e-1 end diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 8025abf..8e6c5a3 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -11,16 +11,16 @@ import Zygote using OptimizationOptimisers -const case = DoubleMM.DoubleMMCase() const MLengine = Val(nameof(SimpleChains)) +const case = DoubleMM.DoubleMMCase() scenario = (:default,) par_templates = get_hybridcase_par_templates(case; scenario) -(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) rng = StableRNG(111) -(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o +(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o ) = gen_hybridcase_synthetic(case, rng; scenario); @testset "gen_hybridcase_synthetic" begin @@ -36,7 +36,7 @@ rng = StableRNG(111) end @testset "loss_g" begin - g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario); + g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); function loss_g(ϕg, x, g) ζMs = g(x, ϕg) # predict the log of the parameters @@ -61,16 +61,17 @@ end end @testset "loss_gf" begin - #----------- fit 
g and θP to y_o - g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario); - f = gen_hybridcase_PBmodel(case; scenario) + #----------- fit g and θP to y_o (without transformations) + g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); + f = get_hybridcase_PBmodel(case; scenario) int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( ϕg = 1:length(ϕg0), θP = par_templates.θP)) p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true # Pass the site-data for the batches as separate vectors wrapped in a tuple - train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch) + #train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch) + train_loader = get_hybridcase_train_dataloader(case, rng; scenario) loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP) l1 = loss_gf(p0, train_loader.data...)[1] diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 0a9f2f0..37bccc2 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -19,20 +19,22 @@ rng = StableRNG(111) const case = DoubleMM.DoubleMMCase() const MLengine = Val(nameof(SimpleChains)) scenario = (:default,) +FT = get_hybridcase_FloatType(case; scenario) #θsite_true = get_hybridcase_par_templates(case; scenario) -g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario); -f = gen_hybridcase_PBmodel(case; scenario) +g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); +f = get_hybridcase_PBmodel(case; scenario) -(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) -(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o +(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o ) = gen_hybridcase_synthetic(case, rng; scenario); -logσ2y = 2 .* log.(σ_o) +logσ2y = FT(2) .* log.(σ_o) n_MC = 3 -transP = elementwise(exp) -transM = Stacked(elementwise(identity), elementwise(exp)) +(; transP, transM) = get_hybridcase_transforms(case; scenario) +# transP = elementwise(exp) +# transM = Stacked(elementwise(identity), elementwise(exp)) #transM = Stacked(elementwise(identity), elementwise(exp), elementwise(exp)) # test mismatch (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( θP_true, θMs_true[:, 1], ϕg0, n_batch; transP, transM); @@ -117,7 +119,7 @@ end; # setup g as FluxNN on gpu using Flux FluxMLengine = Val(nameof(Flux)) -g_flux, ϕg0_flux_cpu = gen_hybridcase_MLapplicator(case, FluxMLengine; scenario) +g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario) if CUDA.functional() @testset "generate_ζ gpu" begin @@ -127,6 +129,7 @@ if CUDA.functional() rng, g_flux, f, ϕ, xMg_batch, map(get_concrete, interpreters); n_MC = 8) @test ζ isa CuMatrix + @test eltype(ζ) == FT gr = Zygote.gradient( ϕ -> sum(CP.generate_ζ( rng, g_flux, f, ϕ, xMg_batch, map(get_concrete, interpreters); @@ -137,13 +140,14 @@ if CUDA.functional() end @testset "neg_elbo_transnorm_gf cpu" begin - cost = neg_elbo_transnorm_gf(rng, g, f, ϕ_ini, y_o[:, 1:n_batch], xM[:, 1:n_batch], - transPMs_batch, map(get_concrete, interpreters); + cost = neg_elbo_transnorm_gf(rng, g, f, ϕ_ini, y_o[:, 1:n_batch], + xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters); n_MC = 8, logσ2y) @test cost isa Float64 gr = Zygote.gradient( ϕ -> neg_elbo_transnorm_gf( - rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch], + rng, g, f, ϕ, y_o[:, 1:n_batch], + xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, 
interpreters; n_MC = 8, logσ2y), CA.getdata(ϕ_ini)) @test gr[1] isa Vector @@ -153,16 +157,20 @@ if CUDA.functional() @testset "neg_elbo_transnorm_gf gpu" begin ϕ = CuArray(CA.getdata(ϕ_ini)) xMg_batch = CuArray(xM[:, 1:n_batch]) - cost = neg_elbo_transnorm_gf(rng, g_flux, f, ϕ, y_o[:, 1:n_batch], xMg_batch, + xP_batch = xP[1:n_batch] # used in f which runs on CPU + cost = neg_elbo_transnorm_gf(rng, g_flux, f, ϕ, y_o[:, 1:n_batch], + xMg_batch, xP_batch, transPMs_batch, map(get_concrete, interpreters); n_MC = 8, logσ2y) @test cost isa Float64 gr = Zygote.gradient( ϕ -> neg_elbo_transnorm_gf( - rng, g_flux, f, ϕ, y_o[:, 1:n_batch], xMg_batch, + rng, g_flux, f, ϕ, y_o[:, 1:n_batch], + xMg_batch, xP_batch, transPMs_batch, interpreters; n_MC = 8, logσ2y), ϕ) @test gr[1] isa CuVector + @test eltype(gr[1]) == FT end end @@ -172,7 +180,7 @@ end trans_PMs_gen = get_transPMs(n_site) @test length(intm_PMs_gen) == 402 @test trans_PMs_gen.length_in == 402 - y_pred = predict_gf(rng, g, f, ϕ_ini, xM, map(get_concrete, interpreters); + y_pred = predict_gf(rng, g, f, ϕ_ini, xM, xP, map(get_concrete, interpreters); get_transPMs, get_ca_int_PMs, n_sample_pred) @test y_pred isa Array @test size(y_pred) == (size(y_o)..., n_sample_pred) @@ -183,7 +191,7 @@ if CUDA.functional() n_sample_pred = 200 ϕ = CuArray(CA.getdata(ϕ_ini)) xMg = CuArray(xM) - y_pred = predict_gf(rng, g_flux, f, ϕ, xMg, map(get_concrete, interpreters); + y_pred = predict_gf(rng, g_flux, f, ϕ, xMg, xP, map(get_concrete, interpreters); get_transPMs, get_ca_int_PMs, n_sample_pred) @test y_pred isa Array @test size(y_pred) == (size(y_o)..., n_sample_pred) diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 1e01dd3..392e76b 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -19,10 +19,10 @@ const case = DoubleMM.DoubleMMCase() #const MLengine = Val(nameof(SimpleChains)) scenario = (:default,) -(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) @testset "test_sample_zeta" begin - (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o + (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o ) = gen_hybridcase_synthetic(case, rng; scenario) # n_site = 2