diff --git a/Project.toml b/Project.toml index 9d93b20..9479100 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Thomas Wutzler and contributors"] version = "1.0.0-DEV" [deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -14,7 +15,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] @@ -28,6 +28,7 @@ HybridVariationalInferenceLuxExt = "Lux" HybridVariationalInferenceSimpleChainsExt = "SimpleChains" [compat] +Bijectors = "0.15.4" BlockDiagonals = "0.1.42" CUDA = "5.5.2" ChainRulesCore = "1.25" @@ -41,7 +42,6 @@ Random = "1.10.0" SimpleChains = "0.4" StatsBase = "0.34.4" StatsFuns = "1.3.2" -TransformVariables = "0.8.10" Zygote = "0.6.73" julia = "1.10" diff --git a/dev/Project.toml b/dev/Project.toml index 162ab31..b5d54ec 100644 --- a/dev/Project.toml +++ b/dev/Project.toml @@ -1,4 +1,5 @@ [deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -15,7 +16,6 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/src/HybridVariationalInference.jl 
b/src/HybridVariationalInference.jl index 5f23d33..370bdd0 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -8,13 +8,10 @@ using GPUArraysCore using LinearAlgebra using CUDA using ChainRulesCore -using TransformVariables +using Bijectors using Zygote # Zygote.@ignore CUDA.randn using BlockDiagonals -export inverse_ca -include("util_transformvariables.jl") - export ComponentArrayInterpreter, flatten1, get_concrete include("ComponentArrayInterpreter.jl") diff --git a/src/elbo.jl b/src/elbo.jl index 32e74f9..13fbac0 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -13,7 +13,7 @@ expected value of the likelihood of observations. interpreted by interpreters.μP_ϕg_unc and interpreters.PMs - y_ob: matrix of observations (n_obs x n_site_batch) - x: matrix of covariates (n_cov x n_site_batch) -- transPMs: Transformations with components P, Ms, similar to interpreters +- transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params - n_MC: number of MonteCarlo samples from the distribution of parameters to simulate using the mechanistic model f. 
- logσ2y: observation uncertainty (log of the variance) @@ -27,7 +27,7 @@ function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractM ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension #ζi = first(eachcol(ζs_cpu)) nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi - y_pred_i, logjac = predict_y(ζi, f, transPMs) + y_pred_i, logjac = predict_y(ζi, f, transPMs, interpreters.PMs) nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y) nLy1 - logjac end) / n_MC @@ -51,10 +51,12 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret n_site = size(xM, 2) intm_PMs_gen = get_ca_int_PMs(n_site) trans_PMs_gen = get_transPMs(n_site) + interpreters_gen = (; interpreters..., PMs = intm_PMs_gen) ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM), - (; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred) + interpreters_gen; n_MC = n_sample_pred) ζs_cpu = gpu_data_handler(ζs) # - y_pred = stack(map(ζ -> first(predict_y(ζ, f, trans_PMs_gen)), eachcol(ζs_cpu))); + y_pred = stack(map(ζ -> first(predict_y( + ζ, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu))); y_pred end @@ -77,9 +79,9 @@ function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix, #ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC) ζ = stack(map(eachcol(ζ_resid)) do r rc = interpreters.PMs(r) - ζP = μ_ζP .+ rc.θP + ζP = μ_ζP .+ rc.P μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g - ζMs = μ_ζMs .+ rc.θMs + ζMs = μ_ζMs .+ rc.Ms vcat(ζP, vec(ζMs)) end) ζ, σ @@ -166,9 +168,11 @@ Steps: - transform the parameters to original constrained space - Applies the mechanistic model for each site """ -function predict_y(ζi, f, transPMs) - θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating - θc = CA.ComponentVector(θtup) +function predict_y(ζi, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter) + 
# θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating + # θc = CA.ComponentVector(θtup) + θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating + θc = int_PMs(θ) # TODO provide xP xP = fill((), size(θc.Ms,2)) y_pred_global, y_pred = f(θc.P, θc.Ms, xP) # TODO parallelize on CPU diff --git a/src/init_hybrid_params.jl b/src/init_hybrid_params.jl index 955d8ba..d010916 100644 --- a/src/init_hybrid_params.jl +++ b/src/init_hybrid_params.jl @@ -14,12 +14,19 @@ Returns a NamedTuple of - `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters - `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator` - `n_batch`: the number of sites to predicted in each mini-batch -- `transP`, `transM`: the Transformations for the global and site-dependent parameters +- `transP`, `transM`: the Bijectors transformations for the global and site-dependent + parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`. + It's the transformation going from unconstrained to constrained space: θ = Tinv(ζ), + because this direction is used much more often. """ -function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ) +function init_hybrid_params(θP, θM, ϕg, n_batch; + transP=elementwise(identity), transM=elementwise(identity)) n_θP = length(θP) n_θM = length(θM) n_ϕg = length(ϕg) + # check that the transformations match the parameter lengths 
+ _ = Bijectors.inverse(transP)(θP) + _ = Bijectors.inverse(transM)(θM) # zero correlation matrices ρsP = zeros(sum(1:(n_θP - 1))) ρsM = zeros(sum(1:(n_θM - 1))) @@ -28,39 +35,35 @@ function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ) coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), ρsP, ρsM) - ϕt = CA.ComponentVector(; - μP = θP, + ϕ = CA.ComponentVector(; + μP = inverse(transP)(θP), ϕg = ϕg, unc = ϕunc0); # get_transPMs = let transP=transP, transM=transM, n_θP=n_θP, n_θM=n_θM function get_transPMs_inner(n_site) - transPMs = as( - (P = as(Array, transP, n_θP), - Ms = as(Array, transM, n_θM, n_site))) + transMs = ntuple(i -> transM, n_site) + ranges = vcat([1:n_θP], [(n_θP + i0*n_θM) .+ (1:n_θM) for i0 in 0:(n_site-1)]) + transPMs = Stacked((transP, transMs...), ranges) + transPMs end end transPMs_batch = get_transPMs(n_batch) - trans_gu = as( - (μP = as(Array, asℝ₊, n_θP), - ϕg = as(Array, n_ϕg), - unc = as(Array, length(ϕunc0)))) - ϕ = inverse_ca(trans_gu, ϕt) - # trans_g = as( - # (μP = as(Array, asℝ₊, n_θP), - # ϕg = as(Array, n_ϕg))) - # + # ranges = (P = 1:n_θP, ϕg = n_θP .+ (1:n_ϕg), unc = (n_θP + n_ϕg) .+ (1:length(ϕunc0))) + # inv_trans_gu = Stacked( + # (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges)) + # ϕ = inv_trans_gu(CA.getdata(ϕt)) get_ca_int_PMs = let function get_ca_int_PMs_inner(n_site) - ComponentArrayInterpreter(CA.ComponentVector(; θP, - θMs = CA.ComponentMatrix( + ComponentArrayInterpreter(CA.ComponentVector(; P=θP, + Ms = CA.ComponentMatrix( zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site)))) end end interpreters = map(get_concrete, (; - μP_ϕg_unc = ComponentArrayInterpreter(ϕt), + μP_ϕg_unc = ComponentArrayInterpreter(ϕ), PMs = get_ca_int_PMs(n_batch), unc = ComponentArrayInterpreter(ϕunc0) )) diff --git a/src/util_transformvariables.jl b/src/util_transformvariables.jl deleted file mode 100644 index 8e008b8..0000000 --- a/src/util_transformvariables.jl +++ 
/dev/null @@ -1,21 +0,0 @@ -""" -Apply TransformVariables.inverse to ComponentArray, `ca`. - -- convert `ca` to a `NamedTuple` -- apply transformation -- convert back to `ComponentArray` -""" -function inverse_ca(trans, ca::CA.AbstractArray) - CA.ComponentArray( - TransformVariables.inverse(trans, cv2NamedTuple(ca)), - CA.getaxes(ca)) -end - -""" -Convert ComponentVector to NamedTuple of the first layer, i.e. keep -ComponentVectors in the second level. -""" -function cv2NamedTuple(ca::CA.ComponentVector) - g = ((k, CA.getdata(ca[k])) for k in keys(ca)) - (; g...) -end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index c4d7183..f1df2cf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -16,6 +17,5 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/test/test_elbo.jl b/test/test_elbo.jl index cbe2199..0a9f2f0 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -7,7 +7,8 @@ using StableRNGs using Random using SimpleChains using ComponentArrays: ComponentArrays as CA -using TransformVariables +#using TransformVariables +using Bijectors using Zygote using CUDA using GPUArraysCore: GPUArraysCore @@ -30,8 +31,11 @@ f = gen_hybridcase_PBmodel(case; scenario) logσ2y = 2 .* log.(σ_o) n_MC = 3 +transP = elementwise(exp) +transM = Stacked(elementwise(identity), elementwise(exp)) +#transM = Stacked(elementwise(identity), elementwise(exp), 
elementwise(exp)) # test mismatch (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP_true, θMs_true[:, 1], ϕg0, n_batch; transP = asℝ₊, transM = asℝ₊); + θP_true, θMs_true[:, 1], ϕg0, n_batch; transP, transM); ϕ_ini = ϕ () -> begin @@ -163,7 +167,11 @@ if CUDA.functional() end @testset "predict_gf cpu" begin - n_sample_pred = 200 + n_sample_pred = n_site = 200 + intm_PMs_gen = get_ca_int_PMs(n_site) + trans_PMs_gen = get_transPMs(n_site) + @test length(intm_PMs_gen) == 402 + @test trans_PMs_gen.length_in == 402 y_pred = predict_gf(rng, g, f, ϕ_ini, xM, map(get_concrete, interpreters); get_transPMs, get_ca_int_PMs, n_sample_pred) @test y_pred isa Array diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 3befbc9..1e01dd3 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -10,7 +10,7 @@ using GPUArraysCore: GPUArraysCore using Random #using SimpleChains using ComponentArrays: ComponentArrays as CA -using TransformVariables +using Bijectors #CUDA.device!(4) rng = StableRNG(111) @@ -23,7 +23,7 @@ scenario = (:default,) @testset "test_sample_zeta" begin (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o -) = gen_hybridcase_synthetic(case, rng; scenario) + ) = gen_hybridcase_synthetic(case, rng; scenario) # n_site = 2 # n_θP, n_θM = length(θ_true.θP), length(θ_true.θM) @@ -31,29 +31,24 @@ scenario = (:default,) # θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM # set to 0.02 rather than zero for debugging non-zero correlations - ρsP = zeros(sum(1:(n_θP - 1))) .+ 0.02 - ρsM = zeros(sum(1:(n_θM - 1))) .+ 0.02 + ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02 + ρsM = zeros(sum(1:(n_θM-1))) .+ 0.02 ϕunc = CA.ComponentVector(; - logσ2_logP = fill(-10.0, n_θP), - coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), + logσ2_logP=fill(-10.0, n_θP), + coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), ρsP, ρsM) - transPMs = as( - (P = as(Array, asℝ₊, n_θP), - Ms = 
as(Array, asℝ₊, n_θM, n_site))) θ_true = θ = CA.ComponentVector(; - P = θP_true, - Ms = θMs_true) - transPMs = as(( - P = as(Array, asℝ₊, n_θP), - Ms = as(Array, asℝ₊, n_θM, n_site))) - ζ_true = inverse_ca(transPMs, θ_true) - ϕ_true = vcat(ζ_true, CA.ComponentVector(unc = ϕunc)) - ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc = ϕunc)) + P=θP_true, + Ms=θMs_true) + transPMs = elementwise(exp) # all parameters on LogNormal scale + ζ_true = inverse(transPMs)(θ_true) + ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc)) + ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc=ϕunc)) - interpreters = (; pmu = ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) + interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) n_MC = 3 @testset "sample_ζ_norm0 cpu" begin