From 151551df0ef9d43cf7e81acddf44358a0f29c280 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 23 Jan 2025 17:06:14 +0100 Subject: [PATCH 1/7] enable independent subranges in parameter s --- dev/doubleMM.jl | 2 +- src/HybridProblem.jl | 19 +++-- src/HybridVariationalInference.jl | 2 +- src/cholesky.jl | 49 +++++++++++ src/elbo.jl | 29 +++---- src/hybrid_case.jl | 26 +++++- test/test_HybridProblem.jl | 56 ++++++++++++- test/test_cholesky_structure.jl | 53 +++++++++--- test/test_elbo.jl | 8 +- test/test_sample_zeta.jl | 134 +++++++++++++++--------------- 10 files changed, 272 insertions(+), 106 deletions(-) diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 5cec624..75b43bf 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -282,7 +282,7 @@ histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_tru # look at θP, θM1 of first site intm_PMs_gen = get_ca_int_PMs(n_site) -ζs, _σ = HVI.generate_ζ(rng, g_flux, f, res.u, xM_gpu, +ζs, _σ = HVI.generate_ζ(rng, g_flux, res.u, xM_gpu, (; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred); ζs = ζs |> Flux.cpu; θPM = vcat(θP_true, θMs_true[:, 1]) diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 65c48c6..1dcc197 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -1,13 +1,14 @@ struct HybridProblem <: AbstractHybridCase θP θM + f + g + ϕg transP transM + cor_starts # = (P=(1,),M=(1,)) n_covar n_batch - f - g - ϕg train_loader # inner constructor to constrain the types function HybridProblem( @@ -17,8 +18,9 @@ struct HybridProblem <: AbstractHybridCase transM::Union{Function, Bijectors.Transform}, transP::Union{Function, Bijectors.Transform}, n_covar::Integer, n_batch::Integer, - train_loader::DataLoader) - new(θP, θM, transM, transP, n_covar, n_batch, f, g, ϕg, train_loader) + train_loader::DataLoader, + cor_starts = (P=(1,), M=(1,))) + new(θP, θM, f, g, ϕg, transM, transP, cor_starts, n_covar, n_batch, train_loader) end end @@ -26,6 +28,10 @@ function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = () (; θP = prob.θP, θM = prob.θM) end +function get_hybridcase_transforms(prob::HybridProblem; scenario::NTuple = ()) + (; transP = prob.transP, transM = prob.transM) +end + function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ()) n_θM = length(prob.θM) n_θP = length(prob.θP) @@ -46,6 +52,9 @@ function get_hybridcase_train_dataloader( return(prob.train_loader) end +function get_hybridcase_cor_starts(prob::HybridProblem; scenario = ()) + prob.cor_starts +end # function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ()) # eltype(prob.θM) diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index a156030..e439de5 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -43,7 +43,7 @@ include("util_opt.jl") export neg_logden_indep_normal, entropy_MvNormal include("logden_normal.jl") -#export - all internal +export get_ca_starts include("cholesky.jl") export neg_elbo_transnorm_gf, predict_gf diff --git a/src/cholesky.jl b/src/cholesky.jl index 5432d3a..f0ff914 100644 --- a/src/cholesky.jl +++ b/src/cholesky.jl @@ -252,6 +252,55 @@ function transformU_cholesky1(v::GPUArraysCore.AbstractGPUVector; n=invsumn(leng return U end +# function transformU_block_cholesky1(v::CA.ComponentVector; +# ns=(invsumn(length(v[k])) + 1 for k in keys(v)) # may pass for efficiency +# ) +# blocks = [transformU_cholesky1(v[k]; n) for (k, n) in zip(keys(v), ns)] +# U = _create_blockdiag(v[first(keys(v))], blocks) # v only for dispatch: plain matrix for gpu +# end + + +""" + get_ca_starts(vc::ComponentVector) + +Return a tuple with starting positions of components in vc. +Useful for providing information on correlactions among subranges in a vector. +""" +function get_ca_starts(vc::CA.ComponentVector) + (1, (1 .+ cumsum((length(vc[k]) for k in front(keys(vc)))))...) +end +"omit the last n elements of an iterator" +front(itr, n=1) = Iterators.take(itr, length(itr)-n) + +""" + transformU_block_cholesky1(v::AbstractVector, cor_starts = (1,)) + +Transform a parameterization v of a blockdiagonal of upper triangular matrices +into the this matrix. +`cor_starts` is a NTuple of Integeres specifying the first column of each block. +E.g. For a matrix with a 3x3, a 2x2, and another block, +the blocks start at colums (1,4,6). It defaults to a single entire block. +""" +function transformU_block_cholesky1(v::AbstractVector, cor_starts = (1,)) + cor_starts_end = (cor_starts..., length(v)+1) + ranges = ChainRulesCore.@ignore_derivatives ( + cor_starts_end[i]:(cor_starts_end[i+1]-1) for i in 1:length(cor_starts)) + blocks = [transformU_cholesky1(v[r]) for r in ranges] + U = _create_blockdiag(v, blocks) # v only for dispatch: plain matrix for gpu + return(U) +end + +function _create_blockdiag(::AbstractArray, blocks) + BlockDiagonal(blocks) +end + +function _create_blockdiag(::GPUArraysCore.AbstractGPUArray, blocks) + # impose no special structure + cat(blocks...; dims=(1, 2)) +end + + + () -> begin tmp = sqrt.(sum(abs2, U_scaled, dims=1)) tmp2 = sum(abs2, U_scaled, dims=1) .^ (-1 / 2) diff --git a/src/elbo.jl b/src/elbo.jl index e0cb2f9..4e83a1e 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -22,9 +22,9 @@ expected value of the likelihood of observations. function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, xM::AbstractMatrix, xP, transPMs, interpreters::NamedTuple; n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler(), - entropyN = 0.0, + cor_starts=(P=(1,),M=(1,)) ) - ζs, σ = generate_ζ(rng, g, f, ϕ, xM, interpreters; n_MC) + ζs, σ = generate_ζ(rng, g, ϕ, xM, interpreters; n_MC, cor_starts) ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension #ζi = first(eachcol(ζs_cpu)) nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi @@ -48,13 +48,14 @@ Prediction function for hybrid model. Returns an Array `(n_obs, n_site, n_sample """ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, xP, interpreters; get_transPMs, get_ca_int_PMs, n_sample_pred=200, - gpu_data_handler=get_default_GPUHandler()) + gpu_data_handler=get_default_GPUHandler(), + cor_starts=(P=(1,),M=(1,))) n_site = size(xM, 2) intm_PMs_gen = get_ca_int_PMs(n_site) trans_PMs_gen = get_transPMs(n_site) interpreters_gen = (; interpreters..., PMs = intm_PMs_gen) - ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM), - interpreters_gen; n_MC = n_sample_pred) + ζs, _ = generate_ζ(rng, g, CA.getdata(ϕ), CA.getdata(xM), + interpreters_gen; n_MC = n_sample_pred, cor_starts) ζs_cpu = gpu_data_handler(ζs) # y_pred = stack(map(ζ -> first(predict_y( ζ, xP, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu))); @@ -69,14 +70,14 @@ Adds the MV-normally distributed residuals, retrieved by `sample_ζ_norm0` to the means extracted from parameters and predicted by the machine learning model. """ -function generate_ζ(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, - interpreters::NamedTuple; n_MC=3) +function generate_ζ(rng, g, ϕ::AbstractVector, xM::AbstractMatrix, + interpreters::NamedTuple; n_MC=3, cor_starts=(P=(1,),M=(1,))) # see documentation of neg_elbo_transnorm_gf ϕc = interpreters.μP_ϕg_unc(CA.getdata(ϕ)) μ_ζP = ϕc.μP ϕg = ϕc.ϕg μ_ζMs0 = g(xM, ϕg) # TODO provide μ_ζP to g - ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC) + ζ_resid, σ = sample_ζ_norm0(rng, μ_ζP, μ_ζMs0, ϕc.unc; n_MC, cor_starts) #ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC) ζ = stack(map(eachcol(ζ_resid)) do r rc = interpreters.PMs(r) @@ -98,21 +99,21 @@ ComponentMarshellers - marsh_batch(n_batch) - marsh_unc(n_UncP, n_UncM, n_UncCorr) """ -function sample_ζ_norm0(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix, ϕunc::AbstractVector, args...; - n_MC=3) +function sample_ζ_norm0(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix, + args...; n_MC, cor_starts) n_θP, n_θMs = length(ζP), length(ζMs) urand = _create_random(rng, CA.getdata(ζP), n_θP + n_θMs, n_MC) - sample_ζ_norm0(urand, ζP, ζMs, ϕunc, args...) + sample_ζ_norm0(urand, ζP, ζMs, args...; cor_starts) end function sample_ζ_norm0(urand::AbstractMatrix, ζP::AbstractVector{T}, ζMs::AbstractMatrix, - ϕunc::AbstractVector, int_unc = ComponentArrayInterpreter(ϕunc); + ϕunc::AbstractVector, int_unc = ComponentArrayInterpreter(ϕunc); cor_starts ) where {T} ϕuncc = int_unc(CA.getdata(ϕunc)) n_θP, n_θMs, (n_θM, n_batch) = length(ζP), length(ζMs), size(ζMs) # make sure to not create a UpperTriangular Matrix of an CuArray in transformU_cholesky1 - UP = transformU_cholesky1(ϕuncc.ρsP) - UM = transformU_cholesky1(ϕuncc.ρsM) + UP = transformU_block_cholesky1(ϕuncc.ρsP, cor_starts.P) + UM = transformU_block_cholesky1(ϕuncc.ρsM, cor_starts.M) cf = ϕuncc.coef_logσ2_logMs logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs) logσ2_logP = vec(CA.getdata(ϕuncc.logσ2_logP)) diff --git a/src/hybrid_case.jl b/src/hybrid_case.jl index 92b3ed1..02a3ec4 100644 --- a/src/hybrid_case.jl +++ b/src/hybrid_case.jl @@ -11,7 +11,8 @@ For a specific case, provide functions that specify details - `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`) optionally - `gen_hybridcase_synthetic` -- `get_hybridcase_FloatType` (defaults to eltype(θM)) +- `get_hybridcase_FloatType` (defaults to `eltype(θM)`) +- `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`) """ abstract type AbstractHybridCase end; @@ -93,7 +94,7 @@ function gen_hybridcase_synthetic end Determine the FloatType for given Case and scenario, defaults to Float32 """ -function get_hybridcase_FloatType(case::AbstractHybridCase; scenario) +function get_hybridcase_FloatType(case::AbstractHybridCase; scenario=()) return eltype(get_hybridcase_par_templates(case; scenario).θM) end @@ -114,5 +115,26 @@ function get_hybridcase_train_dataloader(case::AbstractHybridCase, rng::Abstract return(train_loader) end +""" + get_hybridcase_cor_starts(case::AbstractHybridCase; scenario) + +Specify blocks in correlation matrices among parameters. +Returns a NamedTuple. +- `P`: correlations among global parameters +- `M`: correlations among ML-predicted parameters + +Subsets ofparameters that are correlated with other but not correlated with +parameters of other subranges are specified by indicating the starting position +of each subrange. +E.g. if withing global parameter vector `(p1, p2, p3)`, `p1` and `p2` are correlated, +but parameter `p3` is not correlated with them, +then the first subrange starts at position 1 and the second subrange starts at position 3. +If there is only single block of all ML-predicted parameters being correlated +with each other then this block starts at position 1: `(P=(1,3), M=(1,))`. +""" +function get_hybridcase_cor_starts(case::AbstractHybridCase; scenario = ()) + (P=(1,), M=(1,)) +end + diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index c7757c1..215f0ff 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -12,6 +12,7 @@ import Zygote using OptimizationOptimisers + const MLengine = Val(nameof(SimpleChains)) construct_problem = () -> begin @@ -19,6 +20,7 @@ construct_problem = () -> begin θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2) transP = elementwise(exp) transM = Stacked(elementwise(identity), elementwise(exp)) + cov_starts = (P=(1,2),M=(1)) # assume r0 independent of K2 n_covar = 5 n_batch = 10 int_θdoubleMM = get_concrete(ComponentArrayInterpreter( @@ -53,7 +55,7 @@ construct_problem = () -> begin # HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, # g, ϕg, train_loader) HybridProblem(θP, θM, g_chain, f_doubleMM_with_global, - transM, transP, n_covar, n_batch, train_loader) + transM, transP, n_covar, n_batch, train_loader, cov_starts) end prob = construct_problem(); scenario = (:default,) @@ -93,3 +95,55 @@ scenario = (:default,) @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11) end end + +() -> begin +@testset "neg_elbo_transnorm_gf cpu" begin + rng = StableRNG(111) + g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine); + train_loader = get_hybridcase_train_dataloader(prob) + (xM, xP, y_o) = first(train_loader) + n_batch = size(y_o,2) + f = get_hybridcase_PBmodel(prob) + (θP0, θM0) = get_hybridcase_par_templates(prob) + (; transP, transM) = get_hybridcase_transforms(prob) + + (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( + θP0, θM0, ϕg0, n_batch; transP, transM); + ϕ_ini = ϕ + + cost = neg_elbo_transnorm_gf(rng, g, f, ϕ_ini, y_o, + xM, xP, transPMs_batch, map(get_concrete, interpreters); + n_MC = 8, logσ2y) + @test cost isa Float64 + gr = Zygote.gradient( + ϕ -> neg_elbo_transnorm_gf( + rng, g, f, ϕ, y_o[:, 1:n_batch], + xM[:, 1:n_batch], xP[1:n_batch], + transPMs_batch, interpreters; n_MC = 8, logσ2y), + CA.getdata(ϕ_ini)) + @test gr[1] isa Vector +end; + +if CUDA.functional() + @testset "neg_elbo_transnorm_gf gpu" begin + ϕ = CuArray(CA.getdata(ϕ_ini)) + xMg_batch = CuArray(xM[:, 1:n_batch]) + xP_batch = xP[1:n_batch] # used in f which runs on CPU + cost = neg_elbo_transnorm_gf(rng, g_flux, f, ϕ, y_o[:, 1:n_batch], + xMg_batch, xP_batch, + transPMs_batch, map(get_concrete, interpreters); + n_MC = 8, logσ2y) + @test cost isa Float64 + gr = Zygote.gradient( + ϕ -> neg_elbo_transnorm_gf( + rng, g_flux, f, ϕ, y_o[:, 1:n_batch], + xMg_batch, xP_batch, + transPMs_batch, interpreters; n_MC = 8, logσ2y), + ϕ) + @test gr[1] isa CuVector + @test eltype(gr[1]) == FT + end +end +end #if false + + diff --git a/test/test_cholesky_structure.jl b/test/test_cholesky_structure.jl index b02e07e..c6e6e51 100644 --- a/test/test_cholesky_structure.jl +++ b/test/test_cholesky_structure.jl @@ -75,9 +75,9 @@ end; U1v = CP.vec2uutri(v_orig) @test U1v isa UnitUpperTriangular @test size(U1v, 1) == 4 - gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v))) , vcpu)[1] # works nice + gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v))), vcpu)[1] # works nice # test providing keyword argument - gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v; n=4))) , vcpu)[1] # works nice + gr = Zygote.gradient(v -> sum(abs2.(CP.vec2uutri(v; n = 4))), vcpu)[1] # works nice # v2 = CP.uutri2vec(U1v) @test v2 == v_orig @@ -85,15 +85,15 @@ end; end; @testset "utri2vec_pos" begin - @test CP.utri2vec_pos(1,1) == 1 - @test CP.utri2vec_pos(1,2) == 2 - @test CP.utri2vec_pos(2,2) == 3 - @test CP.utri2vec_pos(1,3) == 4 - @test CP.utri2vec_pos(1,4) == 7 - @test CP.utri2vec_pos(5,5) == 15 - typeof(CP.utri2vec_pos(5,5)) == Int - typeof(CP.utri2vec_pos(Int32(5),Int32(5))) == Int32 - @test_throws AssertionError CP.utri2vec_pos(2,1) + @test CP.utri2vec_pos(1, 1) == 1 + @test CP.utri2vec_pos(1, 2) == 2 + @test CP.utri2vec_pos(2, 2) == 3 + @test CP.utri2vec_pos(1, 3) == 4 + @test CP.utri2vec_pos(1, 4) == 7 + @test CP.utri2vec_pos(5, 5) == 15 + typeof(CP.utri2vec_pos(5, 5)) == Int + typeof(CP.utri2vec_pos(Int32(5), Int32(5))) == Int32 + @test_throws AssertionError CP.utri2vec_pos(2, 1) end @testset "vec2uutri gpu" begin @@ -137,6 +137,36 @@ end; end end; +@testset "transformU_block_cholesky1 gpu" begin + vc = CA.ComponentVector(b1 = [1.0f0], b2 = [2.0f0:5.0f0]) + vc = CA.ComponentVector(b1 = [1.0f0:3.0f0]) + vc = CA.ComponentVector(b1 = 1.0f0:3.0f0, b2 = [5.0f0]) + v = CA.getdata(vc) + cor_starts = get_ca_starts(vc) + #ns=(CP.invsumn(length(v[k])) + 1 for k in keys(v)) + #collect(ns) + U = CP.transformU_block_cholesky1(v, cor_starts) + @test diag(U' * U) ≈ ones(5) + @test U[1:3, 4:5] ≈ zeros(3, 2) + gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(v, cor_starts)), v)[1] # works nice + if CUDA.functional() # only run the test, if CUDA is working (not on Github ci) + vc = v_orig = CA.ComponentVector(b1 = CuArray(1.0f0:3.0f0), b2 = CuArray([5.0f0])) + v = CA.getdata(vc) + cor_starts = get_ca_starts(vc) + U = CP.transformU_block_cholesky1(v, cor_starts) + @test U isa CuArray + @test diag(Array(U' * U)) ≈ ones(5) + @test Array(U[1:3, 4:5]) ≈ zeros(3, 2) + gr1 = Zygote.gradient(v -> sum(CP.transformU_block_cholesky1(v, cor_starts)), v)[1] # works nice + end +end; + +() -> begin + cor_starts = (1,) + cor_starts_end = (cor_starts..., length(v) + 1) + [cor_starts_end[i]:(cor_starts_end[i + 1] - 1) for i in 1:length(cor_starts)] +end + () -> begin #setup for fitting of interactive blocks below _X = rand(3, 3) @@ -251,4 +281,3 @@ end S_pred = Dσ' * Upred' * Upred * Dσ @test S_pred≈S atol=6e-1 end - diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 37bccc2..6b15d64 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -105,12 +105,12 @@ end @testset "generate_ζ" begin ζ, σ = CP.generate_ζ( - rng, g, f, ϕ_ini, xM[:, 1:n_batch], map(get_concrete, interpreters); + rng, g, ϕ_ini, xM[:, 1:n_batch], map(get_concrete, interpreters); n_MC = 8) @test ζ isa Matrix gr = Zygote.gradient( ϕ -> sum(CP.generate_ζ( - rng, g, f, ϕ, xM[:, 1:n_batch], map(get_concrete, interpreters); + rng, g, ϕ, xM[:, 1:n_batch], map(get_concrete, interpreters); n_MC = 8)[1]), CA.getdata(ϕ_ini)) @test gr[1] isa Vector @@ -126,13 +126,13 @@ if CUDA.functional() ϕ = CuArray(CA.getdata(ϕ_ini)) xMg_batch = CuArray(xM[:, 1:n_batch]) ζ, σ = CP.generate_ζ( - rng, g_flux, f, ϕ, xMg_batch, map(get_concrete, interpreters); + rng, g_flux, ϕ, xMg_batch, map(get_concrete, interpreters); n_MC = 8) @test ζ isa CuMatrix @test eltype(ζ) == FT gr = Zygote.gradient( ϕ -> sum(CP.generate_ζ( - rng, g_flux, f, ϕ, xMg_batch, map(get_concrete, interpreters); + rng, g_flux, ϕ, xMg_batch, map(get_concrete, interpreters); n_MC = 8)[1]), ϕ) @test gr[1] isa CuVector diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 392e76b..e60e188 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -21,81 +21,83 @@ scenario = (:default,) (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) -@testset "test_sample_zeta" begin - (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o - ) = gen_hybridcase_synthetic(case, rng; scenario) +(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o +) = gen_hybridcase_synthetic(case, rng; scenario) - # n_site = 2 - # n_θP, n_θM = length(θ_true.θP), length(θ_true.θM) - # σ_θM = θ_true.θM .* 0.1 # 10% around expected - # θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM +# n_site = 2 +# n_θP, n_θM = length(θ_true.θP), length(θ_true.θM) +# σ_θM = θ_true.θM .* 0.1 # 10% around expected +# θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM - # set to 0.02 rather than zero for debugging non-zero correlations - ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02 - ρsM = zeros(sum(1:(n_θM-1))) .+ 0.02 +# set to 0.02 rather than zero for debugging non-zero correlations +ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02 +ρsM = zeros(sum(1:(n_θM-1))) .+ 0.02 - ϕunc = CA.ComponentVector(; - logσ2_logP=fill(-10.0, n_θP), - coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), - ρsP, - ρsM) +ϕunc = CA.ComponentVector(; + logσ2_logP=fill(-10.0, n_θP), + coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), + ρsP, + ρsM) - θ_true = θ = CA.ComponentVector(; - P=θP_true, - Ms=θMs_true) - transPMs = elementwise(exp) # all parameters on LogNormal scale - ζ_true = inverse(transPMs)(θ_true) - ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc)) - ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc=ϕunc)) +θ_true = θ = CA.ComponentVector(; + P=θP_true, + Ms=θMs_true) +transPMs = elementwise(exp) # all parameters on LogNormal scale +ζ_true = inverse(transPMs)(θ_true) +ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc)) +ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc=ϕunc)) - interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) +interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) - n_MC = 3 - @testset "sample_ζ_norm0 cpu" begin - ϕ = CA.getdata(ϕ_cpu) - ϕc = interpreters.pmu(ϕ) - ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC) +n_MC = 3 +@testset "sample_ζ_norm0 cpu" begin + ϕ = CA.getdata(ϕ_cpu) + ϕc = interpreters.pmu(ϕ) + cor_starts=(P=(1,),M=(1,)) + ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_starts) + @test size(ζ_resid) == (length(ϕc.P) + n_θM * n_site, n_MC) + gr = Zygote.gradient(ϕc -> sum(CP.sample_ζ_norm0( + rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_starts)[1]), ϕc)[1] + @test length(gr) == length(ϕ) +end +# + +if CUDA.functional() + @testset "sample_ζ_norm0 gpu" begin + ϕ = CuArray(CA.getdata(ϕ_cpu)) + cor_starts=(P=(1,),M=(1,)) + #tmp = ϕ[1:6] + #vec2uutri(tmp) + ϕc = interpreters.pmu(ϕ); + @test CA.getdata(ϕc) isa GPUArraysCore.AbstractGPUArray + #ζP, ζMs, ϕunc = ϕc.P, ϕc.Ms, ϕc.unc + #urand = CUDA.randn(length(ϕc.P) + length(ϕc.Ms), n_MC) |> gpu + #include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss + #ζ_resid, σ = sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC) + #Zygote.gradient(ϕc -> sum(sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)[1]), ϕc)[1]; + ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_starts) + @test ζ_resid isa GPUArraysCore.AbstractGPUArray @test size(ζ_resid) == (length(ϕc.P) + n_θM * n_site, n_MC) - gr = Zygote.gradient(ϕc -> sum(CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc)[1]), ϕc)[1] + gr = Zygote.gradient( + ϕc -> sum(CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC, cor_starts)[1]), ϕc)[1]; @test length(gr) == length(ϕ) + int_unc = ComponentArrayInterpreter(ϕc.unc) + gr2 = Zygote.gradient( + ϕc -> sum(CP.sample_ζ_norm0(rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), + CA.getdata(ϕc.unc), int_unc; n_MC, cor_starts)[1]), + ϕc)[1]; end - # +end - if CUDA.functional() - @testset "sample_ζ_norm0 gpu" begin - ϕ = CuArray(CA.getdata(ϕ_cpu)) - #tmp = ϕ[1:6] - #vec2uutri(tmp) - ϕc = interpreters.pmu(ϕ) - @test CA.getdata(ϕc) isa GPUArraysCore.AbstractGPUArray - #ζP, ζMs, ϕunc = ϕc.P, ϕc.Ms, ϕc.unc - #urand = CUDA.randn(length(ϕc.P) + length(ϕc.Ms), n_MC) |> gpu - #include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss - #ζ_resid, σ = sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC) - #Zygote.gradient(ϕc -> sum(sample_ζ_norm0(urand, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)[1]), ϕc)[1]; - ζ_resid, σ = CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC) - @test ζ_resid isa GPUArraysCore.AbstractGPUArray - @test size(ζ_resid) == (length(ϕc.P) + n_θM * n_site, n_MC) - gr = Zygote.gradient( - ϕc -> sum(CP.sample_ζ_norm0(rng, ϕc.P, ϕc.Ms, ϕc.unc; n_MC)[1]), ϕc)[1] - @test length(gr) == length(ϕ) - int_unc = ComponentArrayInterpreter(ϕc.unc) - gr2 = Zygote.gradient( - ϕc -> sum(CP.sample_ζ_norm0(rng, CA.getdata(ϕc.P), CA.getdata(ϕc.Ms), - CA.getdata(ϕc.unc), int_unc; n_MC)[1]), - ϕc)[1] - end - end +# @testset "generate_ζ" begin +# ϕ = CA.getdata(ϕ_cpu) +# n_sample_pred = 200 +# intm_PMs_gen = ComponentArrayInterpreter(CA.ComponentVector(; θP_true, +# θMs=CA.ComponentMatrix( +# zeros(n_θM, n_site), first(CA.getaxes(θMs_true)), CA.Axis(i=1:n_sample_pred)))) +# int_μP_ϕg_unc=ComponentArrayInterpreter(ϕ_true) +# interpreters = (; PMs = intm_PMs_gen, μP_ϕg_unc = int_μP_ϕg_unc ) +# ζs, _ = CP.generate_ζ(rng, g, ϕ, xM, interpreters; n_MC=n_sample_pred) - # @testset "generate_ζ" begin - # ϕ = CA.getdata(ϕ_cpu) - # n_sample_pred = 200 - # intm_PMs_gen = ComponentArrayInterpreter(CA.ComponentVector(; θP_true, - # θMs=CA.ComponentMatrix( - # zeros(n_θM, n_site), first(CA.getaxes(θMs_true)), CA.Axis(i=1:n_sample_pred)))) - # int_μP_ϕg_unc=ComponentArrayInterpreter(ϕ_true) - # interpreters = (; PMs = intm_PMs_gen, μP_ϕg_unc = int_μP_ϕg_unc ) - # ζs, _ = CP.generate_ζ(rng, g, f, ϕ, xM, interpreters; n_MC=n_sample_pred) +# end; - # end; -end; From fa44a26a76b3ef2c65d1050f6d286d1340aa3b38 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Fri, 24 Jan 2025 18:01:40 +0100 Subject: [PATCH 2/7] provide logdensity_obs computing function with case and uncertainty with dataloader --- dev/doubleMM.jl | 100 ++++++---------- ext/HybridVariationalInferenceFluxExt.jl | 22 ++-- ext/HybridVariationalInferenceLuxExt.jl | 17 +-- ...bridVariationalInferenceSimpleChainsExt.jl | 18 +-- src/DoubleMM/f_doubleMM.jl | 15 ++- src/HybridProblem.jl | 23 +++- src/HybridVariationalInference.jl | 6 +- src/ModelApplicator.jl | 15 ++- src/elbo.jl | 13 ++- src/gf.jl | 29 ++--- src/hybrid_case.jl | 86 ++++++++------ src/logden_normal.jl | 31 ++--- test/runtests.jl | 4 +- test/test_Flux.jl | 4 +- test/test_HybridProblem.jl | 110 ++++++++++-------- test/test_Lux.jl | 4 +- test/test_SimpleChains.jl | 2 +- test/test_doubleMM.jl | 58 +++++---- test/test_elbo.jl | 99 ++++------------ 19 files changed, 325 insertions(+), 331 deletions(-) diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 75b43bf..4a50a94 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -26,32 +26,32 @@ par_templates = get_hybridcase_par_templates(case; scenario) (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) -(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o +(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridcase_synthetic(case, rng; scenario); #----- fit g to θMs_true g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); +(; transP, transM) = get_hybridcase_transforms(case; scenario) -function loss_g(ϕg, x, g) +function loss_g(ϕg, x, g, transM) ζMs = g(x, ϕg) # predict the log of the parameters - θMs = exp.(ζMs) + θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column loss = sum(abs2, θMs .- θMs_true) return loss, θMs end -loss_g(ϕg0, xM, g) -Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0); +loss_g(ϕg0, xM, g, transM) -optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1], +optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1], Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, ϕg0); res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 800); ϕg_opt1 = res.u; -loss_g(ϕg_opt1, xM, g) -scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) -@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9 +l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM) +scatterplot(vec(θMs_true), vec(θMs_pred)) f = get_hybridcase_PBmodel(case; scenario) +py = get_hybridcase_neg_logden_obs(case; scenario) #----------- fit g and θP to y_o () -> begin @@ -82,13 +82,12 @@ f = get_hybridcase_PBmodel(case; scenario) end #---------- HVI -logσ2y = 2 .* log.(σ_o) n_MC = 3 -transP = elementwise(exp) -transM = Stacked(elementwise(identity), elementwise(exp)) +(; transP, transM) = get_hybridcase_transforms(case; scenario) +FT = get_hybridcase_float_type(case; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP = asℝ₊, transM = asℝ₊); + θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM); ϕ_true = ϕ () -> begin @@ -149,49 +148,21 @@ transM = Stacked(elementwise(identity), elementwise(exp)) ϕ_true = inverse_ca(trans_gu, ϕt_true) end -ϕ_ini0 = ζ = vcat(ϕ_true[:μP] .* 0.0, ϕg0, ϕ_true[[:unc]]); # scratch +ϕ_ini0 = ζ = reduce( + vcat, ( + ϕ_true[[:μP]] .* FT(0.001), CA.ComponentVector(ϕg = ϕg0), ϕ_true[[:unc]])) # scratch # -# true values -ϕ_ini = ζ = vcat(ϕ_true[[:μP, :ϕg]] .* 1.2, ϕ_true[[:unc]]); # slight disturbance +ϕ_ini = ζ = reduce( + vcat, ( + ϕ_true[[:μP]] .- FT(0.1), ϕ_true[[:ϕg]] .* FT(1.1), ϕ_true[[:unc]])) # slight disturbance # hardcoded from HMC inversion ϕ_ini.unc.coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951] ϕ_ini.unc.logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893) mean_σ_o_MC = 0.006042 -# test cost function and gradient -() -> begin - neg_elbo_transnorm_gf(rng, g, f, ϕ_true, y_o[:, 1:n_batch], xM[:, 1:n_batch], - transPMs_batch, map(get_concrete, interpreters); - n_MC = 8, logσ2y) - Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf( - rng, g, f, ϕ, y_o[:, 1:n_batch], xM[:, 1:n_batch], - transPMs_batch, interpreters; n_MC = 8, logσ2y), - CA.getdata(ϕ_true)) -end - -# optimize using SimpleChains -() -> begin - train_loader = MLUtils.DataLoader((xM, y_o), batchsize = n_batch) - - optf = Optimization.OptimizationFunction( - (ϕ, data) -> begin - xM, y_o = data - neg_elbo_transnorm_gf( - rng, g, f, ϕ, y_o, xM, transPMs_batch, - map(get_concrete, interpreters_g); n_MC = 5, logσ2y) - end, - Optimization.AutoZygote()) - optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader) - res = Optimization.solve( - optprob, Optimisers.Adam(0.02), callback = callback_loss(50), maxiters = 800) - #optprob = Optimization.OptimizationProblem(optf, ϕ_ini0); - #res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400); -end - -ϕ = ϕ_ini |> Flux.gpu; +ϕ = CA.getdata(ϕ_ini) |> Flux.gpu; xM_gpu = xM |> Flux.gpu; -g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario); +g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario); # otpimize using LUX () -> begin @@ -216,27 +187,25 @@ g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario g_flux = g_luxs end -function fcost(ϕ, xM, y_o) - neg_elbo_transnorm_gf(rng, g_flux, f, CA.getdata(ϕ), y_o, - xM, transPMs_batch, map(get_concrete, interpreters); - n_MC = 8, logσ2y = logσ2y) +function fcost(ϕ, xM, y_o, y_unc) + neg_elbo_transnorm_gf(rng, g_flux, f, py, CA.getdata(ϕ), y_o, y_unc, + xM, xP, transPMs_batch, map(get_concrete, interpreters); + n_MC = 8) end -fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch]) +fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch]) #Zygote.gradient(fcost, ϕ) |> cpu; gr = Zygote.gradient(fcost, - CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), CA.getdata(y_o[:, 1:n_batch])); -gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ)...) + CA.getdata(ϕ), CA.getdata(xM_gpu[:, 1:n_batch]), + CA.getdata(y_o[:, 1:n_batch]), CA.getdata(y_unc[:, 1:n_batch])); +gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...) -train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch) -train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux)) +train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) +#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux)) optf = Optimization.OptimizationFunction( (ϕ, data) -> begin - xM, y_o = data - fcost(ϕ, xM, y_o) - # neg_elbo_transnorm_gf( - # rng, g_flux, f, ϕ, y_o, xM, transPMs_batch, - # map(get_concrete, interpreters); n_MC = 5, logσ2y) + xM, xP, y_o, y_unc = data + fcost(ϕ, xM, y_o, y_unc) end, Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem( @@ -256,7 +225,7 @@ end ζMs_VI = g_flux(xM_gpu, ζ_VIc.ϕg |> Flux.gpu) |> Flux.cpu ϕunc_VI = interpreters.unc(ζ_VIc.unc) -hcat(θP_true, exp.(ζ_VIc.μP)) +hcat(log.(θP_true), ϕ_ini.μP, ζ_VIc.μP) plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI))) #lineplot!(plt, 0.0, 1.1, identity) # @@ -266,11 +235,12 @@ hcat(ϕ_ini.unc, ϕunc_VI) # need to compare to MC sample # test predicting correct obs-uncertainty of predictive posterior n_sample_pred = 200 -y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, interpreters; +y_pred = predict_gf(rng, g_flux, f, res.u, xM_gpu, xP, interpreters; get_transPMs, get_ca_int_PMs, n_sample_pred); size(y_pred) # n_obs x n_site, n_sample_pred σ_o_post = dropdims(std(y_pred; dims = 3), dims = 3); +σ_o = exp.(y_unc[:,1] / 2) #describe(σ_o_post) hcat(σ_o, fill(mean_σ_o_MC, length(σ_o)), diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl index 1d639bb..b1ad59e 100644 --- a/ext/HybridVariationalInferenceFluxExt.jl +++ b/ext/HybridVariationalInferenceFluxExt.jl @@ -3,12 +3,14 @@ module HybridVariationalInferenceFluxExt using HybridVariationalInference, Flux using HybridVariationalInference: HybridVariationalInference as HVI using ComponentArrays: ComponentArrays as CA +using Random struct FluxApplicator{RT} <: AbstractModelApplicator rebuild::RT end -function HVI.construct_FluxApplicator(m::Chain) +function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type::DataType) + # TODO: care fore rng and float_type ϕ, rebuild = destructure(m) FluxApplicator(rebuild), ϕ end @@ -26,17 +28,17 @@ function __init__() HVI.set_default_GPUHandler(FluxGPUDataHandler()) end -function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain, - args...; kwargs...) - # constructor with Flux.Chain - g, ϕg = construct_FluxApplicator(g_chain) - HybridProblem(θP, θM, g, ϕg, args...; kwargs...) -end +# function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Flux.Chain, +# args...; kwargs...) +# # constructor with Flux.Chain +# g, ϕg = construct_FluxApplicator(g_chain) +# HybridProblem(θP, θM, g, ϕg, args...; kwargs...) +# end -function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux}; +function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux}; scenario::NTuple = ()) (; n_covar, n_θM) = get_hybridcase_sizes(case; scenario) - FloatType = get_hybridcase_FloatType(case; scenario) + float_type = get_hybridcase_float_type(case; scenario) n_out = n_θM is_using_dropout = :use_dropout ∈ scenario is_using_dropout && error("dropout scenario not supported with Flux yet.") @@ -47,7 +49,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{ # dense layer without bias that maps to n outputs and `identity` activation Flux.Dense(n_covar * 4 => n_out, identity, bias = false) ) - construct_FluxApplicator(g_chain) + construct_ChainsApplicator(rng, g_chain, float_type) end diff --git a/ext/HybridVariationalInferenceLuxExt.jl b/ext/HybridVariationalInferenceLuxExt.jl index bfcf6cb..0b54fe5 100644 --- a/ext/HybridVariationalInferenceLuxExt.jl +++ b/ext/HybridVariationalInferenceLuxExt.jl @@ -10,8 +10,8 @@ struct LuxApplicator{MT, IT} <: AbstractModelApplicator int_ϕ::IT end -function HVI.construct_LuxApplicator(m::Chain, float_type=Float32; device = gpu_device()) - ps, st = Lux.setup(Random.default_rng(), m) +function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type=Float32; device = gpu_device()) + ps, st = Lux.setup(rng, m) ps_ca = float_type.(CA.ComponentArray(ps)) st = st |> device stateful_layer = StatefulLuxLayer{true}(m, nothing, st) @@ -25,11 +25,12 @@ function HVI.apply_model(app::LuxApplicator, x, ϕ) app.stateful_layer(x, ϕc) end -function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain, - args...; device = gpu_device(), kwargs...) - # constructor with SimpleChain - g, ϕg = construct_LuxApplicator(g_chain, eltype(θM); device) - HybridProblem(θP, θM, g, ϕg, args...; kwargs...) -end +# function HVI.HybridProblem(rng::AbstractRNG, +# θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::Chain, +# args...; device = gpu_device(), kwargs...) +# # constructor with SimpleChain +# g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM); device) +# HybridProblem(θP, θM, g, ϕg, args...; kwargs...) +# end end # module diff --git a/ext/HybridVariationalInferenceSimpleChainsExt.jl b/ext/HybridVariationalInferenceSimpleChainsExt.jl index f95caa9..b4d0357 100644 --- a/ext/HybridVariationalInferenceSimpleChainsExt.jl +++ b/ext/HybridVariationalInferenceSimpleChainsExt.jl @@ -4,6 +4,7 @@ using HybridVariationalInference, SimpleChains using HybridVariationalInference: HybridVariationalInference as HVI using StatsFuns: logistic using ComponentArrays: ComponentArrays as CA +using Random @@ -11,24 +12,17 @@ struct SimpleChainsApplicator{MT} <: AbstractModelApplicator m::MT end -function HVI.construct_SimpleChainsApplicator(m::SimpleChain, FloatType=Float32) - ϕ = SimpleChains.init_params(m, FloatType); +function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::SimpleChain, FloatType=Float32) + ϕ = SimpleChains.init_params(m, FloatType; rng); SimpleChainsApplicator(m), ϕ end HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ) -function HVI.HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, g_chain::SimpleChain, - args...; kwargs...) - # constructor with SimpleChain - g, ϕg = construct_SimpleChainsApplicator(g_chain) - HybridProblem(θP, θM, g, ϕg, args...; kwargs...) -end - -function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains}; +function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains}; scenario::NTuple=()) (;n_covar, n_θM) = get_hybridcase_sizes(case; scenario) - FloatType = get_hybridcase_FloatType(case; scenario) + FloatType = get_hybridcase_float_type(case; scenario) n_out = n_θM is_using_dropout = :use_dropout ∈ scenario g_chain = if is_using_dropout @@ -52,7 +46,7 @@ function HVI.get_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{ TurboDense{false}(identity, n_out) ) end - construct_SimpleChainsApplicator(g_chain, FloatType) + construct_ChainsApplicator(rng, g_chain, FloatType) end end # module diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index c69680d..6522810 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -22,10 +22,14 @@ function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ()) (; θP, θM) end -function HVI.get_hybridcase_transforms(::AbstractHybridCase; scenario::NTuple = ()) +function HVI.get_hybridcase_transforms(::DoubleMMCase; scenario::NTuple = ()) (; transP, transM) end +function HVI.get_hybridcase_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ()) + neg_logden_indep_normal +end + function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ()) n_covar_pc = 2 n_covar = n_covar_pc + 3 # linear dependent @@ -46,7 +50,7 @@ function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ()) end end -# function HVI.get_hybridcase_FloatType(::DoubleMMCase; scenario) +# function HVI.get_hybridcase_float_type(::DoubleMMCase; scenario) # return Float32 # end @@ -58,7 +62,7 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG; n_covar_pc = 2 n_site = 200 (; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) - FloatType = get_hybridcase_FloatType(case; scenario) + FloatType = get_hybridcase_float_type(case; scenario) xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM; rhodec = 8, is_using_dropout = false) int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,)) @@ -68,6 +72,7 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG; xP = fill((;S1=xP_S1, S2=xP_S2), n_site) y_global_true, y_true = f(θP, θMs_true, xP) σ_o = FloatType(0.01) + logσ2_o = FloatType(2) .* log.(σ_o) #σ_o = 0.002 y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o @@ -81,9 +86,11 @@ function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG; y_true, y_global_o, y_o, - σ_o = fill(σ_o, size(y_true,1)), + y_unc = fill(logσ2_o, size(y_o)), ) end + + diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 1dcc197..0aa44bc 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -4,6 +4,7 @@ struct HybridProblem <: AbstractHybridCase f g ϕg + py transP transM cor_starts # = (P=(1,),M=(1,)) @@ -13,21 +14,35 @@ struct HybridProblem <: AbstractHybridCase # inner constructor to constrain the types function HybridProblem( θP::CA.ComponentVector, θM::CA.ComponentVector, - g::AbstractModelApplicator, ϕg, + g::AbstractModelApplicator, ϕg::AbstractVector, f::Function, + py::Function, transM::Union{Function, Bijectors.Transform}, transP::Union{Function, Bijectors.Transform}, n_covar::Integer, n_batch::Integer, train_loader::DataLoader, - cor_starts = (P=(1,), M=(1,))) - new(θP, θM, f, g, ϕg, transM, transP, cor_starts, n_covar, n_batch, train_loader) + cor_starts::NamedTuple = (P=(1,), M=(1,))) + new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, n_covar, n_batch, train_loader) end end +function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, + # note no ϕg argument and g_chain unconstrained + g_chain, f::Function, + args...; rng = Random.default_rng(), kwargs...) + # dispatches on type of g_chain + g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM)) + HybridProblem(θP, θM, g, ϕg, f, args...; kwargs...) +end + function get_hybridcase_par_templates(prob::HybridProblem; scenario::NTuple = ()) (; θP = prob.θP, θM = prob.θM) end +function get_hybridcase_neg_logden_obs(prob::HybridProblem; scenario::NTuple = ()) + prob.py +end + function get_hybridcase_transforms(prob::HybridProblem; scenario::NTuple = ()) (; transP = prob.transP, transM = prob.transM) end @@ -56,7 +71,7 @@ function get_hybridcase_cor_starts(prob::HybridProblem; scenario = ()) prob.cor_starts end -# function get_hybridcase_FloatType(prob::HybridProblem; scenario::NTuple = ()) +# function get_hybridcase_float_type(prob::HybridProblem; scenario::NTuple = ()) # eltype(prob.θM) # end diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index e439de5..df9bfcf 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -16,15 +16,15 @@ using MLUtils # dataloader export ComponentArrayInterpreter, flatten1, get_concrete include("ComponentArrayInterpreter.jl") -export AbstractModelApplicator, construct_SimpleChainsApplicator, construct_FluxApplicator, - construct_LuxApplicator +export AbstractModelApplicator, construct_ChainsApplicator include("ModelApplicator.jl") export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler include("GPUDataHandler.jl") -export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic, +export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_float_type, gen_hybridcase_synthetic, get_hybridcase_par_templates, get_hybridcase_transforms, get_hybridcase_train_dataloader, + get_hybridcase_neg_logden_obs, gen_cov_pred include("hybrid_case.jl") diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl index 1ada30e..fb928b6 100644 --- a/src/ModelApplicator.jl +++ b/src/ModelApplicator.jl @@ -22,8 +22,17 @@ function apply_model end (app::AbstractModelApplicator)(x, ϕ) = apply_model(app, x, ϕ) -function construct_SimpleChainsApplicator end -function construct_FluxApplicator end -function construct_LuxApplicator end +""" + construct_ChainsApplicator([rng::AbstractRNG,] chain, float_type) +""" +function construct_ChainsApplicator end + +function construct_ChainsApplicator(chain, float_type::DataType; kwargs...) + construct_ChainsApplicator(Random.default_rng(), chain, float_type; kwargs...) +end + +# function construct_SimpleChainsApplicator end +# function construct_FluxApplicator end +# function construct_LuxApplicator end diff --git a/src/elbo.jl b/src/elbo.jl index 4e83a1e..1be7eab 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -8,20 +8,22 @@ expected value of the likelihood of observations. - rng: random number generator (ignored on CUDA, if ϕ is a AbstractGPUArray) - g: machine learning model - f: mechanistic model +- py: negative log-likelihood of observations given predictions: + `function(y_ob, y_pred, y_unc)` - ϕ: flat vector of parameters including parameter of f (ϕ_P), of g (ϕ_Ms), and of VI (ϕ_unc), interpreted by interpreters.μP_ϕg_unc and interpreters.PMs - y_ob: matrix of observations (n_obs x n_site_batch) +- y_unc: observation uncertainty provided to py (same size as y_ob) - xM: matrix of covariates (n_cov x n_site_batch) - xP: model drivers, iterable of (n_site_batch) - transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params - n_MC: number of MonteCarlo samples from the distribution of parameters to simulate using the mechanistic model f. -- logσ2y: observation uncertainty (log of the variance) """ -function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, xM::AbstractMatrix, - xP, transPMs, interpreters::NamedTuple; - n_MC=3, logσ2y, gpu_data_handler = get_default_GPUHandler(), +function neg_elbo_transnorm_gf(rng, g, f, py, ϕ::AbstractVector, y_ob, y_unc, + xM::AbstractMatrix, xP, transPMs, interpreters::NamedTuple; + n_MC=3, gpu_data_handler = get_default_GPUHandler(), cor_starts=(P=(1,),M=(1,)) ) ζs, σ = generate_ζ(rng, g, ϕ, xM, interpreters; n_MC, cor_starts) @@ -29,7 +31,8 @@ function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, xM::Abstract #ζi = first(eachcol(ζs_cpu)) nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi y_pred_i, logjac = predict_y(ζi, xP, f, transPMs, interpreters.PMs) - nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y) + #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc) + nLy1 = py(y_ob, y_pred_i, y_unc) nLy1 - logjac end) / n_MC #sum_log_σ = sum(log.(σ)) diff --git a/src/gf.jl b/src/gf.jl index c86098e..e23f10d 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -17,11 +17,14 @@ end #applyf(f_double, θMs_true, stack(Iterators.repeated(CA.getdata(θP_true), size(θMs_true,2)))) """ -composition f ∘ g: mechanistic model after machine learning parameter prediction +composition f ∘ transM ∘ g: mechanistic model after machine learning parameter prediction """ -function gf(g, f, x, xP, ϕg, θP) - ζMs = g(x, ϕg) # predict the log of the parameters - θMs = exp.(ζMs) +function gf(g, transM, f, xM, xP, ϕg, θP) + # @show first(xM,5) + # @show first(ϕg,5) + ζMs = g(xM, ϕg) # predict the log of the parameters + # @show first(ζMs,5) + θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column y_pred_global, y_pred = f(θP, θMs, xP) return y_pred_global, y_pred, θMs end @@ -30,23 +33,23 @@ end Create a loss function for parameter vector p, given - g(x, ϕ): machine learning model - f(θMs, θP): mechanistic model -- x_o: matrix of covariates, sites in columns +- xM: matrix of covariates, sites in columns - y_o: matrix of observations, sites in columns - int_ϕθP: interpreter attachin axis with compponents ϕg and pc.θP """ -function get_loss_gf(g, f, y_o_global, int_ϕθP::AbstractComponentArrayInterpreter) - let g = g, f = f, int_ϕθP = int_ϕθP - function loss_gf(p, x_o, xP, y_o) +function get_loss_gf(g, transM, f, y_o_global, int_ϕθP::AbstractComponentArrayInterpreter) + let g = g, transM = transM, f = f, int_ϕθP = int_ϕθP + function loss_gf(p, xM, xP, y_o, y_unc) + σ = exp.(y_unc ./ 2) pc = int_ϕθP(p) - y_pred_global, y_pred, θMs = gf(g, f, x_o, xP, pc.ϕg, pc.θP) - #Main.@infiltrate_main - loss = sum(abs2, y_pred .- y_o) + sum(abs2, y_pred_global .- y_o_global) + y_pred_global, y_pred, θMs = gf(g, transM, f, xM, xP, pc.ϕg, pc.θP) + loss = sum(abs2, (y_pred .- y_o) ./ σ) + sum(abs2, y_pred_global .- y_o_global) return loss, y_pred_global, y_pred, θMs end end end () -> begin - loss_gf(p, x_o, y_o) - Zygote.gradient(x -> loss_gf(x, x_o, y_o)[1], p) + loss_gf(p, xM, y_o) + Zygote.gradient(x -> loss_gf(x, xM, y_o)[1], p) end diff --git a/src/hybrid_case.jl b/src/hybrid_case.jl index 02a3ec4..bac0177 100644 --- a/src/hybrid_case.jl +++ b/src/hybrid_case.jl @@ -3,49 +3,23 @@ Type to dispatch constructing data and network structures for different cases of hybrid problem setups For a specific case, provide functions that specify details +- `get_hybridcase_MLapplicator` +- `get_hybridcase_PBmodel` +- `get_hybridcase_neg_logden_obs` - `get_hybridcase_par_templates` - `get_hybridcase_transforms` - `get_hybridcase_sizes` -- `get_hybridcase_MLapplicator` -- `get_hybridcase_PBmodel` - `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`) optionally - `gen_hybridcase_synthetic` -- `get_hybridcase_FloatType` (defaults to `eltype(θM)`) +- `get_hybridcase_float_type` (defaults to `eltype(θM)`) - `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`) """ abstract type AbstractHybridCase end; -""" - get_hybridcase_par_templates(::AbstractHybridCase; scenario) -Provide tuple of templates of ComponentVectors `θP` and `θM`. """ -function get_hybridcase_par_templates end - - -""" - get_hybridcase_transforms(::AbstractHybridCase; scenario) - -Return a NamedTupe of -- `transP`: Bijectors.Transform for the global PBM parameters, θP -- `transM`: Bijectors.Transform for the single-site PBM parameters, θM -""" -function get_hybridcase_transforms end - -""" - get_hybridcase_par_templates(::AbstractHybridCase; scenario) - -Provide a NamedTuple of number of -- n_covar: covariates xM -- n_site: all sites in the data -- n_batch: sites in one minibatch during fitting -- n_θM, n_θP: entries in parameter vectors -""" -function get_hybridcase_sizes end - -""" - get_hybridcase_MLapplicator(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario=()) + get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase, MLEngine; scenario=()) Construct the machine learning model fro given problem case and ML-Framework and scenario. @@ -59,6 +33,10 @@ returns a Tuple of """ function get_hybridcase_MLapplicator end +function get_hybridcase_MLapplicator(case::AbstractHybridCase, MLEngine; scenario=()) + get_hybridcase_MLapplicator(Random.default_rng(), case, MLEngine; scenario) +end + """ get_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=()) @@ -75,6 +53,43 @@ returns a tuple of predictions with components """ function get_hybridcase_PBmodel end +""" + get_hybridcase_neg_logden_obs(::AbstractHybridCase; scenario) + +Provide a `function(y_obs, ypred) -> Real` that computes the negative logdensity +of the observations, given the predictions. +""" +function get_hybridcase_neg_logden_obs end + + +""" + get_hybridcase_par_templates(::AbstractHybridCase; scenario) + +Provide tuple of templates of ComponentVectors `θP` and `θM`. +""" +function get_hybridcase_par_templates end + + +""" + get_hybridcase_transforms(::AbstractHybridCase; scenario) + +Return a NamedTupe of +- `transP`: Bijectors.Transform for the global PBM parameters, θP +- `transM`: Bijectors.Transform for the single-site PBM parameters, θM +""" +function get_hybridcase_transforms end + +""" + get_hybridcase_par_templates(::AbstractHybridCase; scenario) + +Provide a NamedTuple of number of +- n_covar: covariates xM +- n_site: all sites in the data +- n_batch: sites in one minibatch during fitting +- n_θM, n_θP: entries in parameter vectors +""" +function get_hybridcase_sizes end + """ gen_hybridcase_synthetic(::AbstractHybridCase, rng; scenario) @@ -90,11 +105,11 @@ Setup synthetic data, a NamedTuple of function gen_hybridcase_synthetic end """ - get_hybridcase_FloatType(::AbstractHybridCase; scenario) + get_hybridcase_float_type(::AbstractHybridCase; scenario) Determine the FloatType for given Case and scenario, defaults to Float32 """ -function get_hybridcase_FloatType(case::AbstractHybridCase; scenario=()) +function get_hybridcase_float_type(case::AbstractHybridCase; scenario=()) return eltype(get_hybridcase_par_templates(case; scenario).θM) end @@ -105,13 +120,14 @@ Return a DataLoader that provides a tuple of - `xM`: matrix of covariates, with one column per site - `xP`: Iterator of process-model drivers, with one element per site - `y_o`: matrix of observations with added noise, with one column per site +- `y_unc`: matrix `sizeof(y_o)` of uncertainty information """ function get_hybridcase_train_dataloader(case::AbstractHybridCase, rng::AbstractRNG; scenario = ()) - (; xM, xP, y_o) = gen_hybridcase_synthetic(case, rng; scenario) + (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(case, rng; scenario) (; n_batch) = get_hybridcase_sizes(case; scenario) xM_gpu = :use_flux ∈ scenario ? CuArray(xM) : xM - train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o), batchsize = n_batch) + train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) return(train_loader) end diff --git a/src/logden_normal.jl b/src/logden_normal.jl index b4d4dfe..88fb2ce 100644 --- a/src/logden_normal.jl +++ b/src/logden_normal.jl @@ -13,7 +13,7 @@ a low uncertainty estimate and means closer to the observations to help an initial fit. The obtained parameters then can be used as starting values for a the proper fit with `σfac=1.0`. """ -function neg_logden_indep_normal(obs::AbstractVector, μ::AbstractVector, logσ2::AbstractVector; σfac=1.0) +function neg_logden_indep_normal(obs::AbstractArray, μ::AbstractArray, logσ2::AbstractArray; σfac=1.0) # log of independent Normal distributions # estimate independent uncertainty of each θM, rather than full covariance #nlogL = sum(σfac .* log.(σs) .+ 1 / 2 .* abs2.((obs .- μ) ./ σs)) @@ -26,21 +26,22 @@ function neg_logden_indep_normal(obs::AbstractVector, μ::AbstractVector, logσ2 nlogL = sum(σfac .* logσ2 .+ abs2.(obs .- μ) .* exp.(.-logσ2)) / 2 return (nlogL) end -function neg_logden_indep_normal(obss::AbstractMatrix, preds::AbstractMatrix, logσ2::AbstractVector; kwargs...) - nlogLs = map(eachcol(obss), eachcol(preds)) do obs, μ - neg_logden_indep_normal(obs, μ, logσ2; kwargs...) - end - nlogL = sum(nlogLs) - return nlogL -end +# function neg_logden_indep_normal(obss::AbstractMatrix, preds::AbstractMatrix, logσ2::AbstractVector; kwargs...) +# nlogLs = map(eachcol(obss), eachcol(preds)) do obs, μ +# neg_logden_indep_normal(obs, μ, logσ2; kwargs...) +# end +# nlogL = sum(nlogLs) +# return nlogL +# end + +# function neg_logden_indep_normal(obss::AbstractMatrix, preds::AbstractMatrix, logσ2s::AbstractMatrix; kwargs...) +# nlogLs = map(eachcol(obss), eachcol(preds), eachcol(logσ2s)) do obs, μ, logσ2 +# neg_logden_indep_normal(obs, μ, logσ2; kwargs...) +# end +# nlogL = sum(nlogLs) +# return nlogL +# end -function neg_logden_indep_normal(obss::AbstractMatrix, preds::AbstractMatrix, logσ2s::AbstractMatrix; kwargs...) - nlogLs = map(eachcol(obss), eachcol(preds), eachcol(logσ2s)) do obs, μ, logσ2 - neg_logden_indep_normal(obs, μ, logσ2; kwargs...) - end - nlogL = sum(nlogLs) - return nlogL -end entropy_MvNormal(K, logdetΣ) = (K*(1+log(2π)) + logdetΣ)/2 entropy_MvNormal(Σ) = entropy_MvNormal(size(Σ,1), logdet(Σ)) diff --git a/test/runtests.jl b/test/runtests.jl index 78ec965..7599ade 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,14 +13,14 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml @time @safetestset "test_logden_normal" include("test_logden_normal.jl") #@safetestset "test" include("test/test_doubleMM.jl") @time @safetestset "test_doubleMM" include("test_doubleMM.jl") - #@safetestset "test" include("test/test_HybridProblem.jl") - @time @safetestset "test_HybridProblem" include("test_HybridProblem.jl") #@safetestset "test" include("test/test_cholesky_structure.jl") @time @safetestset "test_cholesky_structure" include("test_cholesky_structure.jl") #@safetestset "test" include("test/test_sample_zeta.jl") @time @safetestset "test_sample_zeta" include("test_sample_zeta.jl") #@safetestset "test" include("test/test_elbo.jl") @time @safetestset "test_elbo" include("test_elbo.jl") + #@safetestset "test" include("test/test_HybridProblem.jl") + @time @safetestset "test_HybridProblem" include("test_HybridProblem.jl") # #@safetestset "test" include("test/test_Flux.jl") @time @safetestset "test_Flux" include("test_Flux.jl") diff --git a/test/test_Flux.jl b/test/test_Flux.jl index 6aa62c3..a9378d5 100644 --- a/test/test_Flux.jl +++ b/test/test_Flux.jl @@ -35,9 +35,9 @@ end; Dense(n_covar * 4 => n_covar * 4, tanh), Dense(n_covar * 4 => n_out, identity, bias=false), ) - g, ϕg = construct_FluxApplicator(g_chain |> f64) + g, ϕg = construct_ChainsApplicator(g_chain |> f64, Float64) @test eltype(ϕg) == Float64 - g, ϕg = construct_FluxApplicator(g_chain) + g, ϕg = construct_ChainsApplicator(g_chain, Float32) @test eltype(ϕg) == Float32 n_site = 3 x = rand(Float32, n_covar, n_site) diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 215f0ff..2306d3a 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -12,15 +12,15 @@ import Zygote using OptimizationOptimisers - const MLengine = Val(nameof(SimpleChains)) construct_problem = () -> begin - θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0) - θM = CA.ComponentVector{Float32}(r1 = 0.5, K1 = 0.2) + FT = Float32 + θP = CA.ComponentVector{FT}(r0=0.3, K2=2.0) + θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) transP = elementwise(exp) transM = Stacked(elementwise(identity), elementwise(exp)) - cov_starts = (P=(1,2),M=(1)) # assume r0 independent of K2 + cov_starts = (P=(1, 2), M=(1)) # assume r0 independent of K2 n_covar = 5 n_batch = 10 int_θdoubleMM = get_concrete(ComponentArrayInterpreter( @@ -49,12 +49,14 @@ construct_problem = () -> begin # g, ϕg = construct_SimpleChainsApplicator(g_chain) # rng = StableRNG(111) - (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o -) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng;) - train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch) + # dependency on DeoubleMMCase -> take care of changes in covariates + (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc + ) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng) + py = neg_logden_indep_normal + train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize=n_batch) # HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, # g, ϕg, train_loader) - HybridProblem(θP, θM, g_chain, f_doubleMM_with_global, + HybridProblem(θP, θM, g_chain, f_doubleMM_with_global, py, transM, transP, n_covar, n_batch, train_loader, cov_starts) end prob = construct_problem(); @@ -66,18 +68,19 @@ scenario = (:default,) #----------- fit g and θP to y_o g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario) train_loader = get_hybridcase_train_dataloader(prob; scenario) - (xM, xP, y_o) = first(train_loader) + (xM, xP, y_o, y_unc) = first(train_loader) f = get_hybridcase_PBmodel(prob; scenario) par_templates = get_hybridcase_par_templates(prob; scenario) + (;transM, transP) = get_hybridcase_transforms(prob; scenario) int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( - ϕg = 1:length(ϕg0), θP = par_templates.θP)) + ϕg=1:length(ϕg0), θP=par_templates.θP)) p = p0 = vcat(ϕg0, par_templates.θP .* 0.8) # slightly disturb θP_true # Pass the site-data for the batches as separate vectors wrapped in a tuple y_global_o = Float64[] - loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP) + loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP) l1 = loss_gf(p0, first(train_loader)...) gr = Zygote.gradient(p -> loss_gf(p, train_loader.data...)[1], p0) @test gr[1] isa Vector @@ -89,61 +92,72 @@ scenario = (:default,) res = Optimization.solve( # optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000); - optprob, Adam(0.02), maxiters = 1000) + optprob, Adam(0.02), maxiters=1000) l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...) - @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11) + @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol=0.11) end end -() -> begin +using CUDA +import Flux + @testset "neg_elbo_transnorm_gf cpu" begin rng = StableRNG(111) - g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine); + g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine) train_loader = get_hybridcase_train_dataloader(prob) - (xM, xP, y_o) = first(train_loader) - n_batch = size(y_o,2) + (xM, xP, y_o, y_unc) = first(train_loader) + n_batch = size(y_o, 2) f = get_hybridcase_PBmodel(prob) (θP0, θM0) = get_hybridcase_par_templates(prob) (; transP, transM) = get_hybridcase_transforms(prob) + py = get_hybridcase_neg_logden_obs(prob) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP0, θM0, ϕg0, n_batch; transP, transM); + θP0, θM0, ϕg0, n_batch; transP, transM) ϕ_ini = ϕ - - cost = neg_elbo_transnorm_gf(rng, g, f, ϕ_ini, y_o, + + py = get_hybridcase_neg_logden_obs(prob) + + cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ_ini, y_o, y_unc, xM, xP, transPMs_batch, map(get_concrete, interpreters); - n_MC = 8, logσ2y) + n_MC=8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf( - rng, g, f, ϕ, y_o[:, 1:n_batch], - xM[:, 1:n_batch], xP[1:n_batch], - transPMs_batch, interpreters; n_MC = 8, logσ2y), + ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc, + xM, xP, transPMs_batch, map(get_concrete, interpreters); + n_MC=8), CA.getdata(ϕ_ini)) @test gr[1] isa Vector -end; - -if CUDA.functional() - @testset "neg_elbo_transnorm_gf gpu" begin - ϕ = CuArray(CA.getdata(ϕ_ini)) - xMg_batch = CuArray(xM[:, 1:n_batch]) - xP_batch = xP[1:n_batch] # used in f which runs on CPU - cost = neg_elbo_transnorm_gf(rng, g_flux, f, ϕ, y_o[:, 1:n_batch], - xMg_batch, xP_batch, - transPMs_batch, map(get_concrete, interpreters); - n_MC = 8, logσ2y) - @test cost isa Float64 - gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf( - rng, g_flux, f, ϕ, y_o[:, 1:n_batch], - xMg_batch, xP_batch, - transPMs_batch, interpreters; n_MC = 8, logσ2y), - ϕ) - @test gr[1] isa CuVector - @test eltype(gr[1]) == FT + + if CUDA.functional() + @testset "neg_elbo_transnorm_gf gpu" begin + g, ϕg0 = begin + n_covar = size(xM, 1) + n_out = length(θM0) + g_chain = Flux.Chain( + # dense layer with bias that maps to 8 outputs and applies `tanh` activation + Flux.Dense(n_covar => n_covar * 4, tanh), + Flux.Dense(n_covar * 4 => n_covar * 4, tanh), + # dense layer without bias that maps to n outputs and `identity` activation + Flux.Dense(n_covar * 4 => n_out, identity, bias=false) + ) + construct_ChainsApplicator(g_chain, eltype(θM0)) + end + ϕ_ini.ϕg = ϕg0 + ϕ = CuArray(CA.getdata(ϕ_ini)) + xMg = CuArray(xM) + cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc, + xMg, xP, transPMs_batch, map(get_concrete, interpreters); + n_MC=8) + @test cost isa Float64 + gr = Zygote.gradient( + ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc, + xMg, xP, transPMs_batch, map(get_concrete, interpreters); + n_MC=8), + ϕ) + @test gr[1] isa CuVector + @test eltype(gr[1]) == get_hybridcase_float_type(prob) + end end end -end #if false - - diff --git a/test/test_Lux.jl b/test/test_Lux.jl index d80da03..81b9e15 100644 --- a/test/test_Lux.jl +++ b/test/test_Lux.jl @@ -13,9 +13,9 @@ using StatsFuns: logistic Dense(n_covar * 4 => n_covar * 4, tanh), Dense(n_covar * 4 => n_out, logistic, use_bias=false), ); - g, ϕ = construct_LuxApplicator(g_chain, Float64; device = cpu_device()); + g, ϕ = construct_ChainsApplicator(g_chain, Float64; device = cpu_device()); @test eltype(ϕ) == Float64 - g, ϕ = construct_LuxApplicator(g_chain; device = cpu_device()); + g, ϕ = construct_ChainsApplicator(g_chain, Float32; device = cpu_device()); @test eltype(ϕ) == Float32 n_site = 3 x = rand(Float32, n_covar, n_site) diff --git a/test/test_SimpleChains.jl b/test/test_SimpleChains.jl index 29adb37..076d84c 100644 --- a/test/test_SimpleChains.jl +++ b/test/test_SimpleChains.jl @@ -12,7 +12,7 @@ using StatsFuns: logistic TurboDense{true}(tanh, n_covar * 4), TurboDense{false}(logistic, n_out) ) - g, ϕg = construct_SimpleChainsApplicator(g_chain) + g, ϕg = construct_ChainsApplicator(g_chain, Float32) n_site = 3 x = rand(n_covar, n_site) #ϕg = SimpleChains.init_params(g_chain); diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 8e6c5a3..c89df71 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -20,7 +20,7 @@ par_templates = get_hybridcase_par_templates(case; scenario) (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) rng = StableRNG(111) -(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o +(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridcase_synthetic(case, rng; scenario); @testset "gen_hybridcase_synthetic" begin @@ -36,60 +36,78 @@ rng = StableRNG(111) end @testset "loss_g" begin - g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); - - function loss_g(ϕg, x, g) - ζMs = g(x, ϕg) # predict the log of the parameters - θMs = exp.(ζMs) + g, ϕg0 = get_hybridcase_MLapplicator(rng, case, MLengine; scenario); + (;transP, transM) = get_hybridcase_transforms(case; scenario) + + function loss_g(ϕg, x, g, transM) + # @show first(x,5) + # @show first(ϕg,5) + ζMs = g(x, ϕg) # predict the parameters on unconstrained space + # @show first(ζMs,5) + θMs = reduce(hcat, map(transM, eachcol(ζMs))) # transform each column loss = sum(abs2, θMs .- θMs_true) return loss, θMs end - loss_g(ϕg0, xM, g) - Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0); - - optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1], + loss_g(ϕg0, xM, g, transM) + Zygote.gradient(ϕg -> loss_g(ϕg, xM, g, transM)[1], ϕg0); + # + optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g, transM)[1], Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, ϕg0); #res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600); res = Optimization.solve(optprob, Adam(0.02), maxiters = 600); - + # ϕg_opt1 = res.u; - pred = loss_g(ϕg_opt1, xM, g) - θMs_pred = pred[2] + #first(ϕg_opt1,5) + pred = loss_g(ϕg_opt1, xM, g, transM); + θMs_pred = θMs_pred_1 = pred[2] #scatterplot(vec(θMs_true), vec(θMs_pred)) - @test cor(vec(θMs_true), vec(θMs_pred)) > 0.9 + #@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9 + @test cor(θMs_true[:,1], θMs_pred[:,1]) > 0.9 + @test cor(θMs_true[:,2], θMs_pred[:,2]) > 0.9 end @testset "loss_gf" begin - #----------- fit g and θP to y_o (without transformations) + #----------- fit g and θP to y_o (without uncertainty, without transforming θP) g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); + (;transP, transM) = get_hybridcase_transforms(case; scenario) f = get_hybridcase_PBmodel(case; scenario) int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( ϕg = 1:length(ϕg0), θP = par_templates.θP)) p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true + #p = p0 = vcat(ϕg_opt1, par_templates.θP); # almost true # Pass the site-data for the batches as separate vectors wrapped in a tuple - #train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch) - train_loader = get_hybridcase_train_dataloader(case, rng; scenario) + train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch) + # get_hybridcase_train_dataloader recreates synthetic data differetn θ_true + #train_loader = get_hybridcase_train_dataloader(case, rng; scenario) - loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP) - l1 = loss_gf(p0, train_loader.data...)[1] + loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP) + l1 = loss_gf(p0, first(train_loader)...)[1] + (xM_batch, xP_batch, y_o_batch, y_unc_batch) = first(train_loader) + Zygote.gradient(p0 -> loss_gf(p0, xM_batch, xP_batch, y_o_batch, y_unc_batch)[1], p0) optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1], Optimization.AutoZygote()) optprob = OptimizationProblem(optf, p0, train_loader) res = Optimization.solve( -# optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000); + #optprob, Adam(0.02), callback = callback_loss(100), maxiters = 5000); optprob, Adam(0.02), maxiters = 1000); l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...) + #l1, y_pred_global, y_pred, θMs_pred = loss_gf(p0, xM, xP, y_o, y_unc); + θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true)) @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11) @test cor(vec(θMs_true), vec(θMs_pred)) > 0.9 + @test cor(θMs_true[:,1], θMs_pred[:,1]) > 0.9 + @test cor(θMs_true[:,2], θMs_pred[:,2]) > 0.9 () -> begin scatterplot(vec(θMs_true), vec(θMs_pred)) + scatterplot(θMs_true[1,:], θMs_pred[1,:]) + scatterplot(θMs_true[2,:], θMs_pred[2,:]) scatterplot(log.(vec(θMs_true)), log.(vec(θMs_pred))) scatterplot(vec(y_pred), vec(y_o)) hcat(par_templates.θP, int_ϕθP(p0).θP, int_ϕθP(res.u).θP) diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 6b15d64..16597f8 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -19,18 +19,21 @@ rng = StableRNG(111) const case = DoubleMM.DoubleMMCase() const MLengine = Val(nameof(SimpleChains)) scenario = (:default,) -FT = get_hybridcase_FloatType(case; scenario) +FT = get_hybridcase_float_type(case; scenario) #θsite_true = get_hybridcase_par_templates(case; scenario) g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); f = get_hybridcase_PBmodel(case; scenario) -(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +n_covar = 5 +n_batch = 10 +n_θM, n_θP = values(map(length, get_hybridcase_par_templates(case; scenario))) -(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, σ_o +(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridcase_synthetic(case, rng; scenario); -logσ2y = FT(2) .* log.(σ_o) +py = neg_logden_indep_normal + n_MC = 3 (; transP, transM) = get_hybridcase_transforms(case; scenario) # transP = elementwise(exp) @@ -40,69 +43,6 @@ n_MC = 3 θP_true, θMs_true[:, 1], ϕg0, n_batch; transP, transM); ϕ_ini = ϕ -() -> begin - # correlation matrices - ρsP = zeros(sum(1:(n_θP - 1))) - ρsM = zeros(sum(1:(n_θM - 1))) - - () -> begin - coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951] - logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893) - #mean_σ_o_MC = 0.006042 - - ϕunc = CA.ComponentVector(; - logσ2_logP = logσ2_logP, - coef_logσ2_logMs = coef_logσ2_logMs, - ρsP, - ρsM) - end - - # for a conservative uncertainty assume σ2=1e-10 and no relationship with magnitude - logσ2y = 2 .* log.(σ_o) - ϕunc0 = CA.ComponentVector(; - logσ2_logP = fill(-10.0, n_θP), - coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), - ρsP, - ρsM) - #int_unc = ComponentArrayInterpreter(ϕunc0) - - transPMs_batch = as( - (P = as(Array, asℝ₊, n_θP), - Ms = as(Array, asℝ₊, n_θM, n_batch))) - transPMs_allsites = as( - (P = as(Array, asℝ₊, n_θP), - Ms = as(Array, asℝ₊, n_θM, n_site))) - - ϕ_true = θ = CA.ComponentVector(; - μP = θP_true, - ϕg = ϕg0, #ϕg_opt, # here start from randomized - unc = ϕunc0) - trans_gu = as( - (μP = as(Array, asℝ₊, n_θP), - ϕg = as(Array, length(ϕg0)), - unc = as(Array, length(ϕunc0)))) - trans_g = as( - (μP = as(Array, asℝ₊, n_θP), - ϕg = as(Array, length(ϕg0)))) - - int_PMs_batch = ComponentArrayInterpreter(CA.ComponentVector(; θP = θP_true, - θMs = CA.ComponentMatrix( - zeros(n_θM, n_batch), first(CA.getaxes(θMs_true)), CA.Axis(i = 1:n_batch)))) - - interpreters = map(get_concrete, - (; - μP_ϕg_unc = ComponentArrayInterpreter(ϕ_true), - PMs = int_PMs_batch, - unc = ComponentArrayInterpreter(ϕunc0) - )) - - ϕg_true_vec = CA.ComponentVector( - TransformVariables.inverse(trans_gu, CP.cv2NamedTuple(ϕ_true))) - ϕcg_true = interpreters.μP_ϕg_unc(ϕg_true_vec) - ϕ_ini = ζ = vcat(ϕcg_true[[:μP, :ϕg]] .* 1.2, ϕcg_true[[:unc]]) - ϕ_ini0 = ζ = vcat(ϕcg_true[:μP] .* 0.0, ϕg0, ϕunc0) -end - @testset "generate_ζ" begin ζ, σ = CP.generate_ζ( rng, g, ϕ_ini, xM[:, 1:n_batch], map(get_concrete, interpreters); @@ -140,15 +80,14 @@ if CUDA.functional() end @testset "neg_elbo_transnorm_gf cpu" begin - cost = neg_elbo_transnorm_gf(rng, g, f, ϕ_ini, y_o[:, 1:n_batch], + cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ_ini, y_o[:, 1:n_batch], y_unc[:, 1:n_batch], xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters); - n_MC = 8, logσ2y) + n_MC = 8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf( - rng, g, f, ϕ, y_o[:, 1:n_batch], - xM[:, 1:n_batch], xP[1:n_batch], - transPMs_batch, interpreters; n_MC = 8, logσ2y), + ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o[:, 1:n_batch], y_unc[:, 1:n_batch], + xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters); + n_MC = 8), CA.getdata(ϕ_ini)) @test gr[1] isa Vector end; @@ -158,16 +97,18 @@ if CUDA.functional() ϕ = CuArray(CA.getdata(ϕ_ini)) xMg_batch = CuArray(xM[:, 1:n_batch]) xP_batch = xP[1:n_batch] # used in f which runs on CPU - cost = neg_elbo_transnorm_gf(rng, g_flux, f, ϕ, y_o[:, 1:n_batch], + cost = neg_elbo_transnorm_gf(rng, g_flux, f, py, ϕ, + y_o[:, 1:n_batch], y_unc[:, 1:n_batch], xMg_batch, xP_batch, transPMs_batch, map(get_concrete, interpreters); - n_MC = 8, logσ2y) + n_MC = 8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf( - rng, g_flux, f, ϕ, y_o[:, 1:n_batch], - xMg_batch, xP_batch, - transPMs_batch, interpreters; n_MC = 8, logσ2y), + ϕ -> neg_elbo_transnorm_gf(rng, g_flux, f, py, ϕ, + y_o[:, 1:n_batch], y_unc[:, 1:n_batch], + xMg_batch, xP_batch, + transPMs_batch, map(get_concrete, interpreters); + n_MC = 8), ϕ) @test gr[1] isa CuVector @test eltype(gr[1]) == FT From b1f41a6798e664b04a0072a95f43bcb1b0a8b26a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Fri, 24 Jan 2025 18:07:41 +0100 Subject: [PATCH 3/7] typo --- src/cholesky.jl | 2 +- src/hybrid_case.jl | 2 +- test/test_doubleMM.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cholesky.jl b/src/cholesky.jl index f0ff914..3e3efb9 100644 --- a/src/cholesky.jl +++ b/src/cholesky.jl @@ -279,7 +279,7 @@ Transform a parameterization v of a blockdiagonal of upper triangular matrices into the this matrix. `cor_starts` is a NTuple of Integeres specifying the first column of each block. E.g. For a matrix with a 3x3, a 2x2, and another block, -the blocks start at colums (1,4,6). It defaults to a single entire block. +the blocks start at columns (1,4,6). It defaults to a single entire block. """ function transformU_block_cholesky1(v::AbstractVector, cor_starts = (1,)) cor_starts_end = (cor_starts..., length(v)+1) diff --git a/src/hybrid_case.jl b/src/hybrid_case.jl index bac0177..a85dcdf 100644 --- a/src/hybrid_case.jl +++ b/src/hybrid_case.jl @@ -142,7 +142,7 @@ Returns a NamedTuple. Subsets ofparameters that are correlated with other but not correlated with parameters of other subranges are specified by indicating the starting position of each subrange. -E.g. if withing global parameter vector `(p1, p2, p3)`, `p1` and `p2` are correlated, +E.g. if within global parameter vector `(p1, p2, p3)`, `p1` and `p2` are correlated, but parameter `p3` is not correlated with them, then the first subrange starts at position 1 and the second subrange starts at position 3. If there is only single block of all ML-predicted parameters being correlated diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index c89df71..6313369 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -80,7 +80,7 @@ end # Pass the site-data for the batches as separate vectors wrapped in a tuple train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch) - # get_hybridcase_train_dataloader recreates synthetic data differetn θ_true + # get_hybridcase_train_dataloader recreates synthetic data different θ_true #train_loader = get_hybridcase_train_dataloader(case, rng; scenario) loss_gf = get_loss_gf(g, transM, f, y_global_o, int_ϕθP) From f4acd16fbced33ed87e2e63c32c04767ef20a787 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 30 Jan 2025 09:46:18 +0100 Subject: [PATCH 4/7] remove get_hybrid_case_sizes rather depend on par_templates and train_dataloader and move rng to first position in train_dataloader and synthetic --- dev/doubleMM.jl | 4 +- ext/HybridVariationalInferenceFluxExt.jl | 6 ++- ...bridVariationalInferenceSimpleChainsExt.jl | 5 ++- src/DoubleMM/f_doubleMM.jl | 25 ++++++----- src/HybridProblem.jl | 19 +++----- src/HybridVariationalInference.jl | 4 +- src/elbo.jl | 33 +++++++------- src/hybrid_case.jl | 44 +++++++++++++------ test/test_HybridProblem.jl | 13 +++--- test/test_doubleMM.jl | 9 ++-- test/test_elbo.jl | 2 +- test/test_sample_zeta.jl | 9 +--- 12 files changed, 91 insertions(+), 82 deletions(-) diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 4a50a94..42c4912 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -27,7 +27,7 @@ par_templates = get_hybridcase_par_templates(case; scenario) (; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc -) = gen_hybridcase_synthetic(case, rng; scenario); +) = gen_hybridcase_synthetic(rng, case; scenario); #----- fit g to θMs_true g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); @@ -62,7 +62,7 @@ py = get_hybridcase_neg_logden_obs(case; scenario) p = p0 = vcat(ϕg0, par_templates.θP .* 0.9) # slightly disturb θP_true # Pass the site-data for the batches as separate vectors wrapped in a tuple - train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch) + train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch) loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP) l1 = loss_gf(p0, train_loader.data...)[1] diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl index b1ad59e..4dacaf3 100644 --- a/ext/HybridVariationalInferenceFluxExt.jl +++ b/ext/HybridVariationalInferenceFluxExt.jl @@ -37,9 +37,11 @@ end function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux}; scenario::NTuple = ()) - (; n_covar, n_θM) = get_hybridcase_sizes(case; scenario) + (;θM) = get_hybridcase_par_templates(case; scenario) + n_out = length(θM) + n_covar = 5 + #(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario) float_type = get_hybridcase_float_type(case; scenario) - n_out = n_θM is_using_dropout = :use_dropout ∈ scenario is_using_dropout && error("dropout scenario not supported with Flux yet.") g_chain = Flux.Chain( diff --git a/ext/HybridVariationalInferenceSimpleChainsExt.jl b/ext/HybridVariationalInferenceSimpleChainsExt.jl index b4d0357..f201365 100644 --- a/ext/HybridVariationalInferenceSimpleChainsExt.jl +++ b/ext/HybridVariationalInferenceSimpleChainsExt.jl @@ -21,9 +21,10 @@ HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ) function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains}; scenario::NTuple=()) - (;n_covar, n_θM) = get_hybridcase_sizes(case; scenario) + n_covar = get_hybridcase_n_covar(case; scenario) FloatType = get_hybridcase_float_type(case; scenario) - n_out = n_θM + (;θM) = get_hybridcase_par_templates(case; scenario) + n_out = length(θM) is_using_dropout = :use_dropout ∈ scenario g_chain = if is_using_dropout SimpleChain( diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 6522810..7db2300 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -30,16 +30,16 @@ function HVI.get_hybridcase_neg_logden_obs(::DoubleMMCase; scenario::NTuple = () neg_logden_indep_normal end -function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ()) - n_covar_pc = 2 - n_covar = n_covar_pc + 3 # linear dependent - #n_site = 10^n_covar_pc - n_batch = 10 - n_θM = length(θM) - n_θP = length(θP) - #(; n_covar, n_site, n_batch, n_θM, n_θP) - (; n_covar, n_batch, n_θM, n_θP) -end +# function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ()) +# n_covar_pc = 2 +# n_covar = n_covar_pc + 3 # linear dependent +# #n_site = 10^n_covar_pc +# n_batch = 10 +# n_θM = length(θM) +# n_θP = length(θP) +# #(; n_covar, n_site, n_batch, n_θM, n_θP) +# (; n_covar, n_batch, n_θM, n_θP) +# end function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ()) #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers @@ -57,11 +57,12 @@ end const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1] const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0] -function HVI.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG; +function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase; scenario = ()) n_covar_pc = 2 n_site = 200 - (; n_covar, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) + n_covar = 5 + n_θM = length(θM) FloatType = get_hybridcase_float_type(case; scenario) xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM; rhodec = 8, is_using_dropout = false) diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 0aa44bc..56a5fe4 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -8,8 +8,6 @@ struct HybridProblem <: AbstractHybridCase transP transM cor_starts # = (P=(1,),M=(1,)) - n_covar - n_batch train_loader # inner constructor to constrain the types function HybridProblem( @@ -19,10 +17,9 @@ struct HybridProblem <: AbstractHybridCase py::Function, transM::Union{Function, Bijectors.Transform}, transP::Union{Function, Bijectors.Transform}, - n_covar::Integer, n_batch::Integer, train_loader::DataLoader, cor_starts::NamedTuple = (P=(1,), M=(1,))) - new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, n_covar, n_batch, train_loader) + new(θP, θM, f, g, ϕg, py, transM, transP, cor_starts, train_loader) end end @@ -47,11 +44,11 @@ function get_hybridcase_transforms(prob::HybridProblem; scenario::NTuple = ()) (; transP = prob.transP, transM = prob.transM) end -function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ()) - n_θM = length(prob.θM) - n_θP = length(prob.θP) - (; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP) -end +# function get_hybridcase_sizes(prob::HybridProblem; scenario::NTuple = ()) +# n_θM = length(prob.θM) +# n_θP = length(prob.θP) +# (; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP) +# end function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ()) prob.f @@ -61,9 +58,7 @@ function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::N prob.g, prob.ϕg end -function get_hybridcase_train_dataloader( - prob::HybridProblem, rng::AbstractRNG = Random.default_rng(); - scenario = ()) +function get_hybridcase_train_dataloader(rng::AbstractRNG, prob::HybridProblem; scenario = ()) return(prob.train_loader) end diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index df9bfcf..e34317a 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -22,9 +22,11 @@ include("ModelApplicator.jl") export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler include("GPUDataHandler.jl") -export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_float_type, gen_hybridcase_synthetic, +export AbstractHybridCase, get_hybridcase_MLapplicator, get_hybridcase_PBmodel, + get_hybridcase_float_type, gen_hybridcase_synthetic, get_hybridcase_par_templates, get_hybridcase_transforms, get_hybridcase_train_dataloader, get_hybridcase_neg_logden_obs, + get_hybridcase_n_covar, gen_cov_pred include("hybrid_case.jl") diff --git a/src/elbo.jl b/src/elbo.jl index 1be7eab..f9b1118 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -5,20 +5,21 @@ It generates n_MC samples for each site, and uses these to compute the expected value of the likelihood of observations. ## Arguments -- rng: random number generator (ignored on CUDA, if ϕ is a AbstractGPUArray) -- g: machine learning model -- f: mechanistic model -- py: negative log-likelihood of observations given predictions: - `function(y_ob, y_pred, y_unc)` -- ϕ: flat vector of parameters +- `rng`: random number generator (ignored on CUDA, if ϕ is a AbstractGPUArray) +- `ϕ`: flat vector of parameters including parameter of f (ϕ_P), of g (ϕ_Ms), and of VI (ϕ_unc), interpreted by interpreters.μP_ϕg_unc and interpreters.PMs -- y_ob: matrix of observations (n_obs x n_site_batch) -- y_unc: observation uncertainty provided to py (same size as y_ob) -- xM: matrix of covariates (n_cov x n_site_batch) -- xP: model drivers, iterable of (n_site_batch) -- transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params -- n_MC: number of MonteCarlo samples from the distribution of parameters to simulate +- `g`: machine learning model +- `transPMs`: Transformations as generated by get_transPMs returned from init_hybrid_params +- `f`: mechanistic model +- `py`: negative log-likelihood of observations given predictions: + `function(y_ob, y_pred, y_unc)` +- `xM`: matrix of covariates (n_cov x n_site_batch) +- `xP`: model drivers, iterable of (n_site_batch) +- `y_ob`: matrix of observations (n_obs x n_site_batch) +- `y_unc`: observation uncertainty provided to py (same size as y_ob) +- interpreters: +- `n_MC`: number of MonteCarlo samples from the distribution of parameters to simulate using the mechanistic model f. """ function neg_elbo_transnorm_gf(rng, g, f, py, ϕ::AbstractVector, y_ob, y_unc, @@ -96,11 +97,9 @@ end Extract relevant parameters from θ and return n_MC generated draws together with the vector of standard deviations, σ. -Necessary typestable information on number of compponents are provided with -ComponentMarshellers -- marsh_pmu(n_θP, n_θMs, Unc=n_θUnc) -- marsh_batch(n_batch) -- marsh_unc(n_UncP, n_UncM, n_UncCorr) +## Arguments +`int_unc`: Interpret vector as ComponentVector with components + ρsP, ρsM, logσ2_logP, coef_logσ2_logMs(intercept + slope), """ function sample_ζ_norm0(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs::AbstractMatrix, args...; n_MC, cor_starts) diff --git a/src/hybrid_case.jl b/src/hybrid_case.jl index a85dcdf..a803710 100644 --- a/src/hybrid_case.jl +++ b/src/hybrid_case.jl @@ -8,10 +8,10 @@ For a specific case, provide functions that specify details - `get_hybridcase_neg_logden_obs` - `get_hybridcase_par_templates` - `get_hybridcase_transforms` -- `get_hybridcase_sizes` - `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`) optionally - `gen_hybridcase_synthetic` +- `get_hybridcase_n_covar` (defaults to number of rows in xM in train_dataloader ) - `get_hybridcase_float_type` (defaults to `eltype(θM)`) - `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`) """ @@ -79,19 +79,31 @@ Return a NamedTupe of """ function get_hybridcase_transforms end +# """ +# get_hybridcase_par_templates(::AbstractHybridCase; scenario) +# Provide a NamedTuple of number of +# - n_covar: covariates xM +# - n_site: all sites in the data +# - n_batch: sites in one minibatch during fitting +# - n_θM, n_θP: entries in parameter vectors +# """ +# function get_hybridcase_sizes end + """ - get_hybridcase_par_templates(::AbstractHybridCase; scenario) + get_hybridcase_n_covar(::AbstractHybridCase; scenario) -Provide a NamedTuple of number of -- n_covar: covariates xM -- n_site: all sites in the data -- n_batch: sites in one minibatch during fitting -- n_θM, n_θP: entries in parameter vectors +Provide the number of covariates. Default returns the number of rows in `xM` from +`get_hybridcase_train_dataloader`. """ -function get_hybridcase_sizes end +function get_hybridcase_n_covar(case::AbstractHybridCase; scenario) + train_loader = get_hybridcase_train_dataloader(Random.default_rng(), case; scenario) + (xM, xP, y_o, y_unc) = first(train_loader) + n_covar = size(xM, 1) + return(n_covar) +end """ - gen_hybridcase_synthetic(::AbstractHybridCase, rng; scenario) + gen_hybridcase_synthetic([rng,] ::AbstractHybridCase; scenario) Setup synthetic data, a NamedTuple of - xM: matrix of covariates, with one column per site @@ -114,7 +126,7 @@ function get_hybridcase_float_type(case::AbstractHybridCase; scenario=()) end """ - get_hybridcase_train_dataloader(::AbstractHybridCase, rng; scenario) + get_hybridcase_train_dataloader([rng,] ::AbstractHybridCase; scenario) Return a DataLoader that provides a tuple of - `xM`: matrix of covariates, with one column per site @@ -122,15 +134,21 @@ Return a DataLoader that provides a tuple of - `y_o`: matrix of observations with added noise, with one column per site - `y_unc`: matrix `sizeof(y_o)` of uncertainty information """ -function get_hybridcase_train_dataloader(case::AbstractHybridCase, rng::AbstractRNG; +function get_hybridcase_train_dataloader(rng::AbstractRNG, case::AbstractHybridCase; scenario = ()) - (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(case, rng; scenario) - (; n_batch) = get_hybridcase_sizes(case; scenario) + (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, case; scenario) + n_batch = 10 xM_gpu = :use_flux ∈ scenario ? CuArray(xM) : xM train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) return(train_loader) end +function get_hybridcase_train_dataloader(case::AbstractHybridCase; scenario = ()) + rng::AbstractRNG = Random.default_rng() + get_hybridcase_train_dataloader(rng, case; scenario) +end + + """ get_hybridcase_cor_starts(case::AbstractHybridCase; scenario) diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 2306d3a..f32838a 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -51,23 +51,20 @@ construct_problem = () -> begin rng = StableRNG(111) # dependency on DeoubleMMCase -> take care of changes in covariates (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc - ) = gen_hybridcase_synthetic(DoubleMM.DoubleMMCase(), rng) + ) = gen_hybridcase_synthetic(rng, DoubleMM.DoubleMMCase()) py = neg_logden_indep_normal train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize=n_batch) - # HybridProblem(θP, θM, transM, transP, n_covar, n_batch, f_doubleMM_with_global, - # g, ϕg, train_loader) HybridProblem(θP, θM, g_chain, f_doubleMM_with_global, py, - transM, transP, n_covar, n_batch, train_loader, cov_starts) + transM, transP, train_loader, cov_starts) end prob = construct_problem(); scenario = (:default,) -#(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(prob; scenario) - @testset "loss_gf" begin #----------- fit g and θP to y_o + rng = StableRNG(111) g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario) - train_loader = get_hybridcase_train_dataloader(prob; scenario) + train_loader = get_hybridcase_train_dataloader(rng, prob; scenario) (xM, xP, y_o, y_unc) = first(train_loader) f = get_hybridcase_PBmodel(prob; scenario) par_templates = get_hybridcase_par_templates(prob; scenario) @@ -105,7 +102,7 @@ import Flux @testset "neg_elbo_transnorm_gf cpu" begin rng = StableRNG(111) g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine) - train_loader = get_hybridcase_train_dataloader(prob) + train_loader = get_hybridcase_train_dataloader(rng, prob) (xM, xP, y_o, y_unc) = first(train_loader) n_batch = size(y_o, 2) f = get_hybridcase_PBmodel(prob) diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 6313369..e33d4c8 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -17,11 +17,9 @@ scenario = (:default,) par_templates = get_hybridcase_par_templates(case; scenario) -(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) - rng = StableRNG(111) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc -) = gen_hybridcase_synthetic(case, rng; scenario); +) = gen_hybridcase_synthetic(rng, case; scenario); @testset "gen_hybridcase_synthetic" begin @test isapprox( @@ -31,7 +29,7 @@ rng = StableRNG(111) # test same results for same rng rng2 = StableRNG(111) - gen2 = gen_hybridcase_synthetic(case, rng2; scenario); + gen2 = gen_hybridcase_synthetic(rng2, case; scenario); @test gen2.y_o == y_o end @@ -79,6 +77,7 @@ end #p = p0 = vcat(ϕg_opt1, par_templates.θP); # almost true # Pass the site-data for the batches as separate vectors wrapped in a tuple + n_batch = 10 train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc), batchsize = n_batch) # get_hybridcase_train_dataloader recreates synthetic data different θ_true #train_loader = get_hybridcase_train_dataloader(case, rng; scenario) @@ -99,7 +98,7 @@ end l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...) #l1, y_pred_global, y_pred, θMs_pred = loss_gf(p0, xM, xP, y_o, y_unc); θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true)) - @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11) + #TODO @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.15) @test cor(vec(θMs_true), vec(θMs_pred)) > 0.9 @test cor(θMs_true[:,1], θMs_pred[:,1]) > 0.9 @test cor(θMs_true[:,2], θMs_pred[:,2]) > 0.9 diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 16597f8..96f2d2b 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -30,7 +30,7 @@ n_batch = 10 n_θM, n_θP = values(map(length, get_hybridcase_par_templates(case; scenario))) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc -) = gen_hybridcase_synthetic(case, rng; scenario); +) = gen_hybridcase_synthetic(rng, case; scenario); py = neg_logden_indep_normal diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index e60e188..053c487 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -19,15 +19,10 @@ const case = DoubleMM.DoubleMMCase() #const MLengine = Val(nameof(SimpleChains)) scenario = (:default,) -(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +n_θM, n_θP = length.(values(get_hybridcase_par_templates(case; scenario))) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o -) = gen_hybridcase_synthetic(case, rng; scenario) - -# n_site = 2 -# n_θP, n_θM = length(θ_true.θP), length(θ_true.θM) -# σ_θM = θ_true.θM .* 0.1 # 10% around expected -# θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM +) = gen_hybridcase_synthetic(rng, case; scenario) # set to 0.02 rather than zero for debugging non-zero correlations ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02 From df48ad4815e62ffb8c6c92944aac928a7a25c6d6 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 30 Jan 2025 10:09:47 +0100 Subject: [PATCH 5/7] reorder arguments of elbo random generator parameters functions drivers observation and uncertainties --- dev/doubleMM.jl | 4 ++-- src/elbo.jl | 14 +++++++++----- test/test_HybridProblem.jl | 20 ++++++++++---------- test/test_elbo.jl | 24 ++++++++++++------------ 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 42c4912..7547d23 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -188,8 +188,8 @@ g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario); end function fcost(ϕ, xM, y_o, y_unc) - neg_elbo_transnorm_gf(rng, g_flux, f, py, CA.getdata(ϕ), y_o, y_unc, - xM, xP, transPMs_batch, map(get_concrete, interpreters); + neg_elbo_transnorm_gf(rng, CA.getdata(ϕ), g_flux, transPMs_batch, f, py, + xM, xP, y_o, y_unc, map(get_concrete, interpreters); n_MC = 8) end fcost(ϕ, xM_gpu[:, 1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch]) diff --git a/src/elbo.jl b/src/elbo.jl index f9b1118..3ac23b7 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -6,8 +6,7 @@ expected value of the likelihood of observations. ## Arguments - `rng`: random number generator (ignored on CUDA, if ϕ is a AbstractGPUArray) -- `ϕ`: flat vector of parameters - including parameter of f (ϕ_P), of g (ϕ_Ms), and of VI (ϕ_unc), +- `ϕ`: flat vector of parameters, interpreted by interpreters interpreted by interpreters.μP_ϕg_unc and interpreters.PMs - `g`: machine learning model - `transPMs`: Transformations as generated by get_transPMs returned from init_hybrid_params @@ -18,12 +17,17 @@ expected value of the likelihood of observations. - `xP`: model drivers, iterable of (n_site_batch) - `y_ob`: matrix of observations (n_obs x n_site_batch) - `y_unc`: observation uncertainty provided to py (same size as y_ob) -- interpreters: +- `interpreters`: NamedTuple as generated by `gen_hybridcase_synthetic` with entries: + - `μP_ϕg_unc`: extract components of parameter of + 1) means of global PBM, 2) ML-weights, and 3) additional parameters of approximation q + - `PMs`: assign components to PBM parameters 1 global, 2 matrix of n_site column vectors + - `int_unc` (can be omitted, if `μP_ϕg_unc(ϕ).unc` is already a ComponentVector) - `n_MC`: number of MonteCarlo samples from the distribution of parameters to simulate using the mechanistic model f. """ -function neg_elbo_transnorm_gf(rng, g, f, py, ϕ::AbstractVector, y_ob, y_unc, - xM::AbstractMatrix, xP, transPMs, interpreters::NamedTuple; +function neg_elbo_transnorm_gf(rng, ϕ::AbstractVector, g, transPMs, f, py, + xM::AbstractMatrix, xP, y_ob, y_unc, + interpreters::NamedTuple; n_MC=3, gpu_data_handler = get_default_GPUHandler(), cor_starts=(P=(1,),M=(1,)) ) diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index f32838a..f61f104 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -116,14 +116,14 @@ import Flux py = get_hybridcase_neg_logden_obs(prob) - cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ_ini, y_o, y_unc, - xM, xP, transPMs_batch, map(get_concrete, interpreters); + cost = neg_elbo_transnorm_gf(rng, ϕ_ini, g, transPMs_batch, f, py, + xM, xP, y_o, y_unc, map(get_concrete, interpreters); n_MC=8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc, - xM, xP, transPMs_batch, map(get_concrete, interpreters); - n_MC=8), + ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py, + xM, xP, y_o, y_unc, map(get_concrete, interpreters); + n_MC=8), CA.getdata(ϕ_ini)) @test gr[1] isa Vector @@ -144,14 +144,14 @@ import Flux ϕ_ini.ϕg = ϕg0 ϕ = CuArray(CA.getdata(ϕ_ini)) xMg = CuArray(xM) - cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc, - xMg, xP, transPMs_batch, map(get_concrete, interpreters); + cost = neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py, + xMg, xP, y_o, y_unc, map(get_concrete, interpreters); n_MC=8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o, y_unc, - xMg, xP, transPMs_batch, map(get_concrete, interpreters); - n_MC=8), + ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py, + xMg, xP, y_o, y_unc, map(get_concrete, interpreters); + n_MC=8), ϕ) @test gr[1] isa CuVector @test eltype(gr[1]) == get_hybridcase_float_type(prob) diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 96f2d2b..e083d2d 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -80,13 +80,15 @@ if CUDA.functional() end @testset "neg_elbo_transnorm_gf cpu" begin - cost = neg_elbo_transnorm_gf(rng, g, f, py, ϕ_ini, y_o[:, 1:n_batch], y_unc[:, 1:n_batch], - xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters); + cost = neg_elbo_transnorm_gf(rng, ϕ_ini, g, transPMs_batch, f, py, + xM[:, 1:n_batch], xP[1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch], + map(get_concrete, interpreters); n_MC = 8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf(rng, g, f, py, ϕ, y_o[:, 1:n_batch], y_unc[:, 1:n_batch], - xM[:, 1:n_batch], xP[1:n_batch], transPMs_batch, map(get_concrete, interpreters); + ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g, transPMs_batch, f, py, + xM[:, 1:n_batch], xP[1:n_batch], y_o[:, 1:n_batch], y_unc[:, 1:n_batch], + map(get_concrete, interpreters); n_MC = 8), CA.getdata(ϕ_ini)) @test gr[1] isa Vector @@ -97,17 +99,15 @@ if CUDA.functional() ϕ = CuArray(CA.getdata(ϕ_ini)) xMg_batch = CuArray(xM[:, 1:n_batch]) xP_batch = xP[1:n_batch] # used in f which runs on CPU - cost = neg_elbo_transnorm_gf(rng, g_flux, f, py, ϕ, - y_o[:, 1:n_batch], y_unc[:, 1:n_batch], - xMg_batch, xP_batch, - transPMs_batch, map(get_concrete, interpreters); + cost = neg_elbo_transnorm_gf(rng, ϕ, g_flux, transPMs_batch, f, py, + xMg_batch, xP_batch, y_o[:, 1:n_batch], y_unc[:, 1:n_batch], + map(get_concrete, interpreters); n_MC = 8) @test cost isa Float64 gr = Zygote.gradient( - ϕ -> neg_elbo_transnorm_gf(rng, g_flux, f, py, ϕ, - y_o[:, 1:n_batch], y_unc[:, 1:n_batch], - xMg_batch, xP_batch, - transPMs_batch, map(get_concrete, interpreters); + ϕ -> neg_elbo_transnorm_gf(rng, ϕ, g_flux, transPMs_batch, f, py, + xMg_batch, xP_batch, y_o[:, 1:n_batch], y_unc[:, 1:n_batch], + map(get_concrete, interpreters); n_MC = 8), ϕ) @test gr[1] isa CuVector From 36c63be5f32ada6d47a52dddbe1329c50f5c4a1b Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 30 Jan 2025 12:28:34 +0100 Subject: [PATCH 6/7] remove MLEngine from get_hybridcase_MLapplicator better infer it from scenario, although this looses some type-stability. --- dev/doubleMM.jl | 19 ++++++++----- ext/HybridVariationalInferenceFluxExt.jl | 5 ++-- ...bridVariationalInferenceSimpleChainsExt.jl | 5 ++-- src/DoubleMM/f_doubleMM.jl | 7 +++++ src/HybridProblem.jl | 2 +- src/HybridVariationalInference.jl | 1 + src/ModelApplicator.jl | 28 +++++++++++++++++++ src/hybrid_case.jl | 11 +++----- test/test_HybridProblem.jl | 6 ++-- test/test_doubleMM.jl | 5 ++-- test/test_elbo.jl | 7 ++--- test/test_sample_zeta.jl | 1 - 12 files changed, 66 insertions(+), 31 deletions(-) diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 7547d23..cc70fa2 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -7,7 +7,7 @@ using Statistics using ComponentArrays: ComponentArrays as CA using SimpleChains -import Flux # to allow for FluxMLEngine and cpu() +import Flux using MLUtils import Zygote @@ -17,20 +17,22 @@ using Bijectors using UnicodePlots const case = DoubleMM.DoubleMMCase() -const MLengine = Val(nameof(SimpleChains)) -const FluxMLengine = Val(nameof(Flux)) scenario = (:default,) rng = StableRNG(111) par_templates = get_hybridcase_par_templates(case; scenario) -(; n_covar, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) +#n_covar = get_hybridcase_n_covar(case; scenario) +#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario) (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridcase_synthetic(rng, case; scenario); +n_covar = size(xM,1) + + #----- fit g to θMs_true -g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); +g, ϕg0 = get_hybridcase_MLapplicator(case; scenario); (; transP, transM) = get_hybridcase_transforms(case; scenario) function loss_g(ϕg, x, g, transM) @@ -90,6 +92,8 @@ FT = get_hybridcase_float_type(case; scenario) θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM); ϕ_true = ϕ + + () -> begin coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951] logσ2_logP = CA.ComponentVector(r0 = -8.997, K2 = -5.893) @@ -162,7 +166,8 @@ mean_σ_o_MC = 0.006042 ϕ = CA.getdata(ϕ_ini) |> Flux.gpu; xM_gpu = xM |> Flux.gpu; -g_flux, _ = get_hybridcase_MLapplicator(case, FluxMLengine; scenario); +scenario_flux = (scenario..., :use_Flux) +g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux); # otpimize using LUX () -> begin @@ -200,7 +205,7 @@ gr = Zygote.gradient(fcost, gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...) train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) -#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_flux)) +#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux)) optf = Optimization.OptimizationFunction( (ϕ, data) -> begin diff --git a/ext/HybridVariationalInferenceFluxExt.jl b/ext/HybridVariationalInferenceFluxExt.jl index 4dacaf3..02b9fa4 100644 --- a/ext/HybridVariationalInferenceFluxExt.jl +++ b/ext/HybridVariationalInferenceFluxExt.jl @@ -35,11 +35,12 @@ end # HybridProblem(θP, θM, g, ϕg, args...; kwargs...) # end -function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:Flux}; +function HVI.construct_3layer_MLApplicator( + rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux}; scenario::NTuple = ()) (;θM) = get_hybridcase_par_templates(case; scenario) n_out = length(θM) - n_covar = 5 + n_covar = get_hybridcase_n_covar(case; scenario) #(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario) float_type = get_hybridcase_float_type(case; scenario) is_using_dropout = :use_dropout ∈ scenario diff --git a/ext/HybridVariationalInferenceSimpleChainsExt.jl b/ext/HybridVariationalInferenceSimpleChainsExt.jl index f201365..df1a122 100644 --- a/ext/HybridVariationalInferenceSimpleChainsExt.jl +++ b/ext/HybridVariationalInferenceSimpleChainsExt.jl @@ -19,8 +19,9 @@ end HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ) -function HVI.get_hybridcase_MLapplicator(rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains}; - scenario::NTuple=()) +function HVI.construct_3layer_MLApplicator( + rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains}; + scenario::NTuple = ()) n_covar = get_hybridcase_n_covar(case; scenario) FloatType = get_hybridcase_float_type(case; scenario) (;θM) = get_hybridcase_par_templates(case; scenario) diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 7db2300..ee6c21f 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -91,6 +91,13 @@ function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase; ) end +function HVI.get_hybridcase_MLapplicator( + rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase; scenario = ()) + ml_engine = select_ml_engine(; scenario) + construct_3layer_MLApplicator(rng, case, ml_engine; scenario) +end + + diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 56a5fe4..9f89801 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -54,7 +54,7 @@ function get_hybridcase_PBmodel(prob::HybridProblem; scenario::NTuple = ()) prob.f end -function get_hybridcase_MLapplicator(prob::HybridProblem, ml_engine; scenario::NTuple = ()); +function get_hybridcase_MLapplicator(prob::HybridProblem; scenario::NTuple = ()); prob.g, prob.ϕg end diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index e34317a..6a56427 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -17,6 +17,7 @@ export ComponentArrayInterpreter, flatten1, get_concrete include("ComponentArrayInterpreter.jl") export AbstractModelApplicator, construct_ChainsApplicator +export construct_3layer_MLApplicator, select_ml_engine include("ModelApplicator.jl") export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl index fb928b6..de14697 100644 --- a/src/ModelApplicator.jl +++ b/src/ModelApplicator.jl @@ -36,3 +36,31 @@ end # function construct_LuxApplicator end +""" + construct_3layer_MLApplicator( + rng::AbstractRNG, case::HVI.AbstractHybridCase, ; + scenario::NTuple = ()) + +`ml_engine` usually is of type `Val{Symbol}`, e.g. Val(:Flux). See `select_ml_engine`. +""" +function construct_3layer_MLApplicator end + +""" + select_ml_engine(;scenario) + +Returns a value type `Val{:Symbol}` to dispatch on the machine learning engine to use. +- defaults to `Val(:SimpleChains)` +- `:use_Lux ∈ scenario -> Val(:Lux)` +- `:use_Flux ∈ scenario -> Val(:Flux)` +""" +function select_ml_engine(;scenario) + if :use_Lux ∈ scenario + return Val(:Lux) + elseif :use_Flux ∈ scenario + return Val(:Flux) + else + # default + return Val(:SimpleChains) + end +end + diff --git a/src/hybrid_case.jl b/src/hybrid_case.jl index a803710..585a266 100644 --- a/src/hybrid_case.jl +++ b/src/hybrid_case.jl @@ -19,22 +19,19 @@ abstract type AbstractHybridCase end; """ - get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase, MLEngine; scenario=()) + get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase; scenario=()) Construct the machine learning model fro given problem case and ML-Framework and scenario. -The MLEngine is a value type of a Symbol, usually the name of the module, e.g. -`const MLengine = Val(nameof(SimpleChains))`. - returns a Tuple of - AbstractModelApplicator - initial parameter vector """ function get_hybridcase_MLapplicator end -function get_hybridcase_MLapplicator(case::AbstractHybridCase, MLEngine; scenario=()) - get_hybridcase_MLapplicator(Random.default_rng(), case, MLEngine; scenario) +function get_hybridcase_MLapplicator(case::AbstractHybridCase; scenario=()) + get_hybridcase_MLapplicator(Random.default_rng(), case; scenario) end """ @@ -138,7 +135,7 @@ function get_hybridcase_train_dataloader(rng::AbstractRNG, case::AbstractHybridC scenario = ()) (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, case; scenario) n_batch = 10 - xM_gpu = :use_flux ∈ scenario ? CuArray(xM) : xM + xM_gpu = :use_Flux ∈ scenario ? CuArray(xM) : xM train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch) return(train_loader) end diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index f61f104..1430cde 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -12,8 +12,6 @@ import Zygote using OptimizationOptimisers -const MLengine = Val(nameof(SimpleChains)) - construct_problem = () -> begin FT = Float32 θP = CA.ComponentVector{FT}(r0=0.3, K2=2.0) @@ -63,7 +61,7 @@ scenario = (:default,) @testset "loss_gf" begin #----------- fit g and θP to y_o rng = StableRNG(111) - g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine; scenario) + g, ϕg0 = get_hybridcase_MLapplicator(prob; scenario) train_loader = get_hybridcase_train_dataloader(rng, prob; scenario) (xM, xP, y_o, y_unc) = first(train_loader) f = get_hybridcase_PBmodel(prob; scenario) @@ -101,7 +99,7 @@ import Flux @testset "neg_elbo_transnorm_gf cpu" begin rng = StableRNG(111) - g, ϕg0 = get_hybridcase_MLapplicator(prob, MLengine) + g, ϕg0 = get_hybridcase_MLapplicator(prob) train_loader = get_hybridcase_train_dataloader(rng, prob) (xM, xP, y_o, y_unc) = first(train_loader) n_batch = size(y_o, 2) diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index e33d4c8..0c49d78 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -11,7 +11,6 @@ import Zygote using OptimizationOptimisers -const MLengine = Val(nameof(SimpleChains)) const case = DoubleMM.DoubleMMCase() scenario = (:default,) @@ -34,7 +33,7 @@ rng = StableRNG(111) end @testset "loss_g" begin - g, ϕg0 = get_hybridcase_MLapplicator(rng, case, MLengine; scenario); + g, ϕg0 = get_hybridcase_MLapplicator(rng, case; scenario); (;transP, transM) = get_hybridcase_transforms(case; scenario) function loss_g(ϕg, x, g, transM) @@ -67,7 +66,7 @@ end @testset "loss_gf" begin #----------- fit g and θP to y_o (without uncertainty, without transforming θP) - g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); + g, ϕg0 = get_hybridcase_MLapplicator(case; scenario); (;transP, transM) = get_hybridcase_transforms(case; scenario) f = get_hybridcase_PBmodel(case; scenario) diff --git a/test/test_elbo.jl b/test/test_elbo.jl index e083d2d..28ce205 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -17,12 +17,11 @@ using GPUArraysCore: GPUArraysCore rng = StableRNG(111) const case = DoubleMM.DoubleMMCase() -const MLengine = Val(nameof(SimpleChains)) scenario = (:default,) FT = get_hybridcase_float_type(case; scenario) #θsite_true = get_hybridcase_par_templates(case; scenario) -g, ϕg0 = get_hybridcase_MLapplicator(case, MLengine; scenario); +g, ϕg0 = get_hybridcase_MLapplicator(case; scenario); f = get_hybridcase_PBmodel(case; scenario) n_covar = 5 @@ -58,8 +57,8 @@ end; # setup g as FluxNN on gpu using Flux -FluxMLengine = Val(nameof(Flux)) -g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case, FluxMLengine; scenario) +scenario_flux = (scenario..., :use_Flux) +g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case; scenario = scenario_flux) if CUDA.functional() @testset "generate_ζ gpu" begin diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 053c487..b776fe8 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -16,7 +16,6 @@ using Bijectors rng = StableRNG(111) const case = DoubleMM.DoubleMMCase() -#const MLengine = Val(nameof(SimpleChains)) scenario = (:default,) n_θM, n_θP = length.(values(get_hybridcase_par_templates(case; scenario))) From 463531b52fd8c174599ce56aa7946aee928f8223 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Thu, 30 Jan 2025 12:40:24 +0100 Subject: [PATCH 7/7] fix error on calling CUDA before testing its functional --- .github/workflows/CI.yml | 3 ++- test/test_elbo.jl | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ae44819..796183d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -49,7 +49,8 @@ jobs: with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + #fail_ci_if_error: false + fail_ci_if_error: true docs: name: Documentation runs-on: ubuntu-latest diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 28ce205..dc26383 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -57,8 +57,11 @@ end; # setup g as FluxNN on gpu using Flux -scenario_flux = (scenario..., :use_Flux) -g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case; scenario = scenario_flux) + +if CUDA.functional() + scenario_flux = (scenario..., :use_Flux) + g_flux, ϕg0_flux_cpu = get_hybridcase_MLapplicator(case; scenario = scenario_flux) +end if CUDA.functional() @testset "generate_ζ gpu" begin