From cb594da635b1e8e30d9083309a9aebc08b253664 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Fri, 17 Jan 2025 12:07:05 +0100 Subject: [PATCH 1/2] use Bijectors.jl for variable transformations, rather than TransformVariables allows more flexibility for specifying and combining transformations combinations applicable (also the inverse) to vectors (rather than tuples) Better support for AD on GPU --- Project.toml | 4 +-- dev/Project.toml | 2 +- src/HybridVariationalInference.jl | 5 +--- src/elbo.jl | 22 ++++++++++------- src/init_hybrid_params.jl | 41 +++++++++++++++++-------------- src/util_transformvariables.jl | 21 ---------------- test/Project.toml | 2 +- test/test_elbo.jl | 14 ++++++++--- test/test_sample_zeta.jl | 31 ++++++++++------------- 9 files changed, 64 insertions(+), 78 deletions(-) delete mode 100644 src/util_transformvariables.jl diff --git a/Project.toml b/Project.toml index 9d93b20..9479100 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Thomas Wutzler and contributors"] version = "1.0.0-DEV" [deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -14,7 +15,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] @@ -28,6 +28,7 @@ HybridVariationalInferenceLuxExt = "Lux" HybridVariationalInferenceSimpleChainsExt = "SimpleChains" [compat] +Bijectors = "0.15.4" BlockDiagonals = "0.1.42" CUDA = "5.5.2" ChainRulesCore = "1.25" @@ -41,7 +42,6 @@ Random = "1.10.0" SimpleChains = "0.4" StatsBase = "0.34.4" StatsFuns = "1.3.2" -TransformVariables = "0.8.10" Zygote = "0.6.73" julia = "1.10" diff --git 
a/dev/Project.toml b/dev/Project.toml index 162ab31..b5d54ec 100644 --- a/dev/Project.toml +++ b/dev/Project.toml @@ -1,4 +1,5 @@ [deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -15,7 +16,6 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index 5f23d33..370bdd0 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -8,13 +8,10 @@ using GPUArraysCore using LinearAlgebra using CUDA using ChainRulesCore -using TransformVariables +using Bijectors using Zygote # Zygote.@ignore CUDA.randn using BlockDiagonals -export inverse_ca -include("util_transformvariables.jl") - export ComponentArrayInterpreter, flatten1, get_concrete include("ComponentArrayInterpreter.jl") diff --git a/src/elbo.jl b/src/elbo.jl index 32e74f9..13fbac0 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -13,7 +13,7 @@ expected value of the likelihood of observations. interpreted by interpreters.μP_ϕg_unc and interpreters.PMs - y_ob: matrix of observations (n_obs x n_site_batch) - x: matrix of covariates (n_cov x n_site_batch) -- transPMs: Transformations with components P, Ms, similar to interpreters +- transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params - n_MC: number of MonteCarlo samples from the distribution of parameters to simulate using the mechanistic model f. 
- logσ2y: observation uncertainty (log of the variance) @@ -27,7 +27,7 @@ function neg_elbo_transnorm_gf(rng, g, f, ϕ::AbstractVector, y_ob, x::AbstractM ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension #ζi = first(eachcol(ζs_cpu)) nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi - y_pred_i, logjac = predict_y(ζi, f, transPMs) + y_pred_i, logjac = predict_y(ζi, f, transPMs, interpreters.PMs) nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y) nLy1 - logjac end) / n_MC @@ -51,10 +51,12 @@ function predict_gf(rng, g, f, ϕ::AbstractVector, xM::AbstractMatrix, interpret n_site = size(xM, 2) intm_PMs_gen = get_ca_int_PMs(n_site) trans_PMs_gen = get_transPMs(n_site) + interpreters_gen = (; interpreters..., PMs = intm_PMs_gen) ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM), - (; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred) + interpreters_gen; n_MC = n_sample_pred) ζs_cpu = gpu_data_handler(ζs) # - y_pred = stack(map(ζ -> first(predict_y(ζ, f, trans_PMs_gen)), eachcol(ζs_cpu))); + y_pred = stack(map(ζ -> first(predict_y( + ζ, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu))); y_pred end @@ -77,9 +79,9 @@ function generate_ζ(rng, g, f, ϕ::AbstractVector, x::AbstractMatrix, #ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC) ζ = stack(map(eachcol(ζ_resid)) do r rc = interpreters.PMs(r) - ζP = μ_ζP .+ rc.θP + ζP = μ_ζP .+ rc.P μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g - ζMs = μ_ζMs .+ rc.θMs + ζMs = μ_ζMs .+ rc.Ms vcat(ζP, vec(ζMs)) end) ζ, σ @@ -166,9 +168,11 @@ Steps: - transform the parameters to original constrained space - Applies the mechanistic model for each site """ -function predict_y(ζi, f, transPMs) - θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating - θc = CA.ComponentVector(θtup) +function predict_y(ζi, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter) + 
# θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating + # θc = CA.ComponentVector(θtup) + θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating + θc = int_PMs(θ) # TODO provide xP xP = fill((), size(θc.Ms,2)) y_pred_global, y_pred = f(θc.P, θc.Ms, xP) # TODO parallelize on CPU diff --git a/src/init_hybrid_params.jl b/src/init_hybrid_params.jl index 955d8ba..e2b7f93 100644 --- a/src/init_hybrid_params.jl +++ b/src/init_hybrid_params.jl @@ -14,12 +14,19 @@ Returns a NamedTuple of - `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters - `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator` - `n_batch`: the number of sites to predicted in each mini-batch -- `transP`, `transM`: the Transformations for the global and site-dependent parameters +- `transP`, `transM`: the Bijectors.jl transformations for the global and site-dependent + parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`. + It is the transformation going from unconstrained to constrained space: θ = Tinv(ζ), + because this direction is used much more often. """ -function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ) +function init_hybrid_params(θP, θM, ϕg, n_batch; + transP=elementwise(identity), transM=elementwise(identity)) n_θP = length(θP) n_θM = length(θM) n_ϕg = length(ϕg) + # check translating parameters - can match lenght? 
+ _ = Bijectors.inverse(transP)(θP) + _ = Bijectors.inverse(transM)(θM) # zero correlation matrices ρsP = zeros(sum(1:(n_θP - 1))) ρsM = zeros(sum(1:(n_θM - 1))) @@ -28,39 +35,35 @@ function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ) coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), ρsP, ρsM) - ϕt = CA.ComponentVector(; - μP = θP, + ϕ = CA.ComponentVector(; + μP = inverse(transP)(θP), ϕg = ϕg, unc = ϕunc0); # get_transPMs = let transP=transP, transM=transM, n_θP=n_θP, n_θM=n_θM function get_transPMs_inner(n_site) - transPMs = as( - (P = as(Array, transP, n_θP), - Ms = as(Array, transM, n_θM, n_site))) + transMs = ntuple(i -> transM, n_site) + ranges = vcat([1:n_θP], [(n_θP + i0*n_θM) .+ (1:n_θM) for i0 in 0:(n_site-1)]) + transPMs = Stacked((transP, transMs...), ranges) + transPMs end end transPMs_batch = get_transPMs(n_batch) - trans_gu = as( - (μP = as(Array, asℝ₊, n_θP), - ϕg = as(Array, n_ϕg), - unc = as(Array, length(ϕunc0)))) - ϕ = inverse_ca(trans_gu, ϕt) - # trans_g = as( - # (μP = as(Array, asℝ₊, n_θP), - # ϕg = as(Array, n_ϕg))) - # + # ranges = (P = 1:n_θP, ϕg = n_θP .+ (1:n_ϕg), unc = (n_θP + n_ϕg) .+ (1:length(ϕunc0))) + # inv_trans_gu = Stacked( + # (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges)) + # ϕ = inv_trans_gu(CA.getdata(ϕt)) get_ca_int_PMs = let function get_ca_int_PMs_inner(n_site) - ComponentArrayInterpreter(CA.ComponentVector(; θP, - θMs = CA.ComponentMatrix( + ComponentArrayInterpreter(CA.ComponentVector(; P=θP, + Ms = CA.ComponentMatrix( zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site)))) end end interpreters = map(get_concrete, (; - μP_ϕg_unc = ComponentArrayInterpreter(ϕt), + μP_ϕg_unc = ComponentArrayInterpreter(ϕ), PMs = get_ca_int_PMs(n_batch), unc = ComponentArrayInterpreter(ϕunc0) )) diff --git a/src/util_transformvariables.jl b/src/util_transformvariables.jl deleted file mode 100644 index 8e008b8..0000000 --- a/src/util_transformvariables.jl +++ 
/dev/null @@ -1,21 +0,0 @@ -""" -Apply TransformVariables.inverse to ComponentArray, `ca`. - -- convert `ca` to a `NamedTuple` -- apply transformation -- convert back to `ComponentArray` -""" -function inverse_ca(trans, ca::CA.AbstractArray) - CA.ComponentArray( - TransformVariables.inverse(trans, cv2NamedTuple(ca)), - CA.getaxes(ca)) -end - -""" -Convert ComponentVector to NamedTuple of the first layer, i.e. keep -ComponentVectors in the second level. -""" -function cv2NamedTuple(ca::CA.ComponentVector) - g = ((k, CA.getdata(ca[k])) for k in keys(ca)) - (; g...) -end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index c4d7183..f1df2cf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -16,6 +17,5 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/test/test_elbo.jl b/test/test_elbo.jl index cbe2199..0a9f2f0 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -7,7 +7,8 @@ using StableRNGs using Random using SimpleChains using ComponentArrays: ComponentArrays as CA -using TransformVariables +#using TransformVariables +using Bijectors using Zygote using CUDA using GPUArraysCore: GPUArraysCore @@ -30,8 +31,11 @@ f = gen_hybridcase_PBmodel(case; scenario) logσ2y = 2 .* log.(σ_o) n_MC = 3 +transP = elementwise(exp) +transM = Stacked(elementwise(identity), elementwise(exp)) +#transM = Stacked(elementwise(identity), elementwise(exp), 
elementwise(exp)) # test mismatch (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP_true, θMs_true[:, 1], ϕg0, n_batch; transP = asℝ₊, transM = asℝ₊); + θP_true, θMs_true[:, 1], ϕg0, n_batch; transP, transM); ϕ_ini = ϕ () -> begin @@ -163,7 +167,11 @@ if CUDA.functional() end @testset "predict_gf cpu" begin - n_sample_pred = 200 + n_sample_pred = n_site = 200 + intm_PMs_gen = get_ca_int_PMs(n_site) + trans_PMs_gen = get_transPMs(n_site) + @test length(intm_PMs_gen) == 402 + @test trans_PMs_gen.length_in == 402 y_pred = predict_gf(rng, g, f, ϕ_ini, xM, map(get_concrete, interpreters); get_transPMs, get_ca_int_PMs, n_sample_pred) @test y_pred isa Array diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 3befbc9..1e01dd3 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -10,7 +10,7 @@ using GPUArraysCore: GPUArraysCore using Random #using SimpleChains using ComponentArrays: ComponentArrays as CA -using TransformVariables +using Bijectors #CUDA.device!(4) rng = StableRNG(111) @@ -23,7 +23,7 @@ scenario = (:default,) @testset "test_sample_zeta" begin (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o -) = gen_hybridcase_synthetic(case, rng; scenario) + ) = gen_hybridcase_synthetic(case, rng; scenario) # n_site = 2 # n_θP, n_θM = length(θ_true.θP), length(θ_true.θM) @@ -31,29 +31,24 @@ scenario = (:default,) # θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM # set to 0.02 rather than zero for debugging non-zero correlations - ρsP = zeros(sum(1:(n_θP - 1))) .+ 0.02 - ρsM = zeros(sum(1:(n_θM - 1))) .+ 0.02 + ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02 + ρsM = zeros(sum(1:(n_θM-1))) .+ 0.02 ϕunc = CA.ComponentVector(; - logσ2_logP = fill(-10.0, n_θP), - coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), + logσ2_logP=fill(-10.0, n_θP), + coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)), ρsP, ρsM) - transPMs = as( - (P = as(Array, asℝ₊, n_θP), - Ms = 
as(Array, asℝ₊, n_θM, n_site))) θ_true = θ = CA.ComponentVector(; - P = θP_true, - Ms = θMs_true) - transPMs = as(( - P = as(Array, asℝ₊, n_θP), - Ms = as(Array, asℝ₊, n_θM, n_site))) - ζ_true = inverse_ca(transPMs, θ_true) - ϕ_true = vcat(ζ_true, CA.ComponentVector(unc = ϕunc)) - ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc = ϕunc)) + P=θP_true, + Ms=θMs_true) + transPMs = elementwise(exp) # all parameters on LogNormal scale + ζ_true = inverse(transPMs)(θ_true) + ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc)) + ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc=ϕunc)) - interpreters = (; pmu = ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) + interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs) n_MC = 3 @testset "sample_ζ_norm0 cpu" begin From 81e656d7ddf2916b0176156cc7505ef9d71efa18 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Fri, 17 Jan 2025 12:09:49 +0100 Subject: [PATCH 2/2] typo --- src/init_hybrid_params.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/init_hybrid_params.jl b/src/init_hybrid_params.jl index e2b7f93..d010916 100644 --- a/src/init_hybrid_params.jl +++ b/src/init_hybrid_params.jl @@ -24,7 +24,7 @@ function init_hybrid_params(θP, θM, ϕg, n_batch; n_θP = length(θP) n_θM = length(θM) n_ϕg = length(ϕg) - # check translating parameters - can match lenght? + # check translating parameters - can match length? _ = Bijectors.inverse(transP)(θP) _ = Bijectors.inverse(transM)(θM) # zero correlation matrices