Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Thomas Wutzler <[email protected]> and contributors"]
version = "1.0.0-DEV"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -14,7 +15,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Expand All @@ -28,6 +28,7 @@ HybridVariationalInferenceLuxExt = "Lux"
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"

[compat]
Bijectors = "0.15.4"
BlockDiagonals = "0.1.42"
CUDA = "5.5.2"
ChainRulesCore = "1.25"
Expand All @@ -41,7 +42,6 @@ Random = "1.10.0"
SimpleChains = "0.4"
StatsBase = "0.34.4"
StatsFuns = "1.3.2"
TransformVariables = "0.8.10"
Zygote = "0.6.73"
julia = "1.10"

Expand Down
2 changes: 1 addition & 1 deletion dev/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand All @@ -15,7 +16,6 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
5 changes: 1 addition & 4 deletions src/HybridVariationalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@ using GPUArraysCore
using LinearAlgebra
using CUDA
using ChainRulesCore
using TransformVariables
using Bijectors
using Zygote # Zygote.@ignore CUDA.randn
using BlockDiagonals

export inverse_ca
include("util_transformvariables.jl")

export ComponentArrayInterpreter, flatten1, get_concrete
include("ComponentArrayInterpreter.jl")

Expand Down
22 changes: 13 additions & 9 deletions src/elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
interpreted by interpreters.μP_ϕg_unc and interpreters.PMs
- y_ob: matrix of observations (n_obs x n_site_batch)
- x: matrix of covariates (n_cov x n_site_batch)
- transPMs: Transformations with components P, Ms, similar to interpreters
- transPMs: Transformations as generated by get_transPMs returned from init_hybrid_params
- n_MC: number of MonteCarlo samples from the distribution of parameters to simulate
using the mechanistic model f.
- logσ2y: observation uncertainty (log of the variance)
Expand All @@ -27,7 +27,7 @@
ζs_cpu = gpu_data_handler(ζs) # differentiable fetch to CPU in Flux package extension
#ζi = first(eachcol(ζs_cpu))
nLy = reduce(+, map(eachcol(ζs_cpu)) do ζi
y_pred_i, logjac = predict_y(ζi, f, transPMs)
y_pred_i, logjac = predict_y(ζi, f, transPMs, interpreters.PMs)

Check warning on line 30 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L30

Added line #L30 was not covered by tests
nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, logσ2y)
nLy1 - logjac
end) / n_MC
Expand All @@ -51,10 +51,12 @@
n_site = size(xM, 2)
intm_PMs_gen = get_ca_int_PMs(n_site)
trans_PMs_gen = get_transPMs(n_site)
interpreters_gen = (; interpreters..., PMs = intm_PMs_gen)

Check warning on line 54 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L54

Added line #L54 was not covered by tests
ζs, _ = generate_ζ(rng, g, f, CA.getdata(ϕ), CA.getdata(xM),
(; interpreters..., PMs = intm_PMs_gen); n_MC = n_sample_pred)
interpreters_gen; n_MC = n_sample_pred)
ζs_cpu = gpu_data_handler(ζs) #
y_pred = stack(map(ζ -> first(predict_y(ζ, f, trans_PMs_gen)), eachcol(ζs_cpu)));
y_pred = stack(map(ζ -> first(predict_y(

Check warning on line 58 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L58

Added line #L58 was not covered by tests
ζ, f, trans_PMs_gen, interpreters_gen.PMs)), eachcol(ζs_cpu)));
y_pred
end

Expand All @@ -77,9 +79,9 @@
#ζ_resid, σ = sample_ζ_norm0(rng, ϕ[1:2], reshape(ϕ[2 .+ (1:20)],2,:), ϕ[(end-length(interpreters.unc)+1):end], interpreters.unc; n_MC)
ζ = stack(map(eachcol(ζ_resid)) do r
rc = interpreters.PMs(r)
ζP = μ_ζP .+ rc.θP
ζP = μ_ζP .+ rc.P

Check warning on line 82 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L82

Added line #L82 was not covered by tests
μ_ζMs = μ_ζMs0 # g(x, ϕc.ϕ) # TODO provide ζP to g
ζMs = μ_ζMs .+ rc.θMs
ζMs = μ_ζMs .+ rc.Ms

Check warning on line 84 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L84

Added line #L84 was not covered by tests
vcat(ζP, vec(ζMs))
end)
ζ, σ
Expand Down Expand Up @@ -166,9 +168,11 @@
- transform the parameters to original constrained space
- Applies the mechanistic model for each site
"""
function predict_y(ζi, f, transPMs)
θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating
θc = CA.ComponentVector(θtup)
function predict_y(ζi, f, transPMs::Bijectors.Transform, int_PMs::AbstractComponentArrayInterpreter)

Check warning on line 171 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L171

Added line #L171 was not covered by tests
# θtup, logjac = transform_and_logjac(transPMs, ζi) # both allocating
# θc = CA.ComponentVector(θtup)
θ, logjac = Bijectors.with_logabsdet_jacobian(transPMs, ζi) # both allocating
θc = int_PMs(θ)

Check warning on line 175 in src/elbo.jl

View check run for this annotation

Codecov / codecov/patch

src/elbo.jl#L174-L175

Added lines #L174 - L175 were not covered by tests
# TODO provide xP
xP = fill((), size(θc.Ms,2))
y_pred_global, y_pred = f(θc.P, θc.Ms, xP) # TODO parallelize on CPU
Expand Down
41 changes: 22 additions & 19 deletions src/init_hybrid_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@
- `θP`, `θM`: Template ComponentVectors of global parameters and ML-predicted parameters
- `ϕg`: vector of parameters to optimize, as returned by `gen_hybridcase_MLapplicator`
- `n_batch`: the number of sites to predict in each mini-batch
- `transP`, `transM`: the Transformations for the global and site-dependent parameters
- `transP`, `transM`: the Bijectors transformations for the global and site-dependent
parameters, e.g. `Stacked(elementwise(identity), elementwise(exp), elementwise(exp))`.
Each is the transformation going from unconstrained to constrained space: θ = Tinv(ζ),
because this direction is used much more often.
"""
function init_hybrid_params(θP, θM, ϕg, n_batch; transP=asℝ, transM=asℝ)
function init_hybrid_params(θP, θM, ϕg, n_batch;

Check warning on line 22 in src/init_hybrid_params.jl

View check run for this annotation

Codecov / codecov/patch

src/init_hybrid_params.jl#L22

Added line #L22 was not covered by tests
transP=elementwise(identity), transM=elementwise(identity))
n_θP = length(θP)
n_θM = length(θM)
n_ϕg = length(ϕg)
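# check that the inverse transformations can be applied to the parameter templates
# (intended to fail early if a transformation's length does not match)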
_ = Bijectors.inverse(transP)(θP)
_ = Bijectors.inverse(transM)(θM)

Check warning on line 29 in src/init_hybrid_params.jl

View check run for this annotation

Codecov / codecov/patch

src/init_hybrid_params.jl#L28-L29

Added lines #L28 - L29 were not covered by tests
# zero correlation matrices
ρsP = zeros(sum(1:(n_θP - 1)))
ρsM = zeros(sum(1:(n_θM - 1)))
Expand All @@ -28,39 +35,35 @@
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
ρsP,
ρsM)
ϕt = CA.ComponentVector(;
μP = θP,
ϕ = CA.ComponentVector(;

Check warning on line 38 in src/init_hybrid_params.jl

View check run for this annotation

Codecov / codecov/patch

src/init_hybrid_params.jl#L38

Added line #L38 was not covered by tests
μP = inverse(transP)(θP),
ϕg = ϕg,
unc = ϕunc0);
#
get_transPMs = let transP=transP, transM=transM, n_θP=n_θP, n_θM=n_θM
function get_transPMs_inner(n_site)
transPMs = as(
(P = as(Array, transP, n_θP),
Ms = as(Array, transM, n_θM, n_site)))
transMs = ntuple(i -> transM, n_site)
ranges = vcat([1:n_θP], [(n_θP + i0*n_θM) .+ (1:n_θM) for i0 in 0:(n_site-1)])
transPMs = Stacked((transP, transMs...), ranges)
transPMs

Check warning on line 48 in src/init_hybrid_params.jl

View check run for this annotation

Codecov / codecov/patch

src/init_hybrid_params.jl#L45-L48

Added lines #L45 - L48 were not covered by tests
end
end
transPMs_batch = get_transPMs(n_batch)
trans_gu = as(
(μP = as(Array, asℝ₊, n_θP),
ϕg = as(Array, n_ϕg),
unc = as(Array, length(ϕunc0))))
ϕ = inverse_ca(trans_gu, ϕt)
# trans_g = as(
# (μP = as(Array, asℝ₊, n_θP),
# ϕg = as(Array, n_ϕg)))
#
# ranges = (P = 1:n_θP, ϕg = n_θP .+ (1:n_ϕg), unc = (n_θP + n_ϕg) .+ (1:length(ϕunc0)))
# inv_trans_gu = Stacked(
# (inverse(transP), elementwise(identity), elementwise(identity)), values(ranges))
# ϕ = inv_trans_gu(CA.getdata(ϕt))
get_ca_int_PMs = let
function get_ca_int_PMs_inner(n_site)
ComponentArrayInterpreter(CA.ComponentVector(; θP,
θMs = CA.ComponentMatrix(
ComponentArrayInterpreter(CA.ComponentVector(; P=θP,

Check warning on line 58 in src/init_hybrid_params.jl

View check run for this annotation

Codecov / codecov/patch

src/init_hybrid_params.jl#L58

Added line #L58 was not covered by tests
Ms = CA.ComponentMatrix(
zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i = 1:n_site))))
end

end
interpreters = map(get_concrete,
(;
μP_ϕg_unc = ComponentArrayInterpreter(ϕt),
μP_ϕg_unc = ComponentArrayInterpreter(ϕ),
PMs = get_ca_int_PMs(n_batch),
unc = ComponentArrayInterpreter(ϕunc0)
))
Expand Down
21 changes: 0 additions & 21 deletions src/util_transformvariables.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -16,6 +17,5 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
14 changes: 11 additions & 3 deletions test/test_elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using StableRNGs
using Random
using SimpleChains
using ComponentArrays: ComponentArrays as CA
using TransformVariables
#using TransformVariables
using Bijectors
using Zygote
using CUDA
using GPUArraysCore: GPUArraysCore
Expand All @@ -30,8 +31,11 @@ f = gen_hybridcase_PBmodel(case; scenario)

logσ2y = 2 .* log.(σ_o)
n_MC = 3
transP = elementwise(exp)
transM = Stacked(elementwise(identity), elementwise(exp))
#transM = Stacked(elementwise(identity), elementwise(exp), elementwise(exp)) # test mismatch
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
θP_true, θMs_true[:, 1], ϕg0, n_batch; transP = asℝ₊, transM = asℝ₊);
θP_true, θMs_true[:, 1], ϕg0, n_batch; transP, transM);
ϕ_ini = ϕ

() -> begin
Expand Down Expand Up @@ -163,7 +167,11 @@ if CUDA.functional()
end

@testset "predict_gf cpu" begin
n_sample_pred = 200
n_sample_pred = n_site = 200
intm_PMs_gen = get_ca_int_PMs(n_site)
trans_PMs_gen = get_transPMs(n_site)
@test length(intm_PMs_gen) == 402
@test trans_PMs_gen.length_in == 402
y_pred = predict_gf(rng, g, f, ϕ_ini, xM, map(get_concrete, interpreters);
get_transPMs, get_ca_int_PMs, n_sample_pred)
@test y_pred isa Array
Expand Down
31 changes: 13 additions & 18 deletions test/test_sample_zeta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using GPUArraysCore: GPUArraysCore
using Random
#using SimpleChains
using ComponentArrays: ComponentArrays as CA
using TransformVariables
using Bijectors

#CUDA.device!(4)
rng = StableRNG(111)
Expand All @@ -23,37 +23,32 @@ scenario = (:default,)

@testset "test_sample_zeta" begin
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
) = gen_hybridcase_synthetic(case, rng; scenario)
) = gen_hybridcase_synthetic(case, rng; scenario)

# n_site = 2
# n_θP, n_θM = length(θ_true.θP), length(θ_true.θM)
# σ_θM = θ_true.θM .* 0.1 # 10% around expected
# θMs_true = θ_true.θM .+ randn(n_θM, n_site) .* σ_θM

# set to 0.02 rather than zero for debugging non-zero correlations
ρsP = zeros(sum(1:(n_θP - 1))) .+ 0.02
ρsM = zeros(sum(1:(n_θM - 1))) .+ 0.02
ρsP = zeros(sum(1:(n_θP-1))) .+ 0.02
ρsM = zeros(sum(1:(n_θM-1))) .+ 0.02

ϕunc = CA.ComponentVector(;
logσ2_logP = fill(-10.0, n_θP),
coef_logσ2_logMs = reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
logσ2_logP=fill(-10.0, n_θP),
coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
ρsP,
ρsM)

transPMs = as(
(P = as(Array, asℝ₊, n_θP),
Ms = as(Array, asℝ₊, n_θM, n_site)))
θ_true = θ = CA.ComponentVector(;
P = θP_true,
Ms = θMs_true)
transPMs = as((
P = as(Array, asℝ₊, n_θP),
Ms = as(Array, asℝ₊, n_θM, n_site)))
ζ_true = inverse_ca(transPMs, θ_true)
ϕ_true = vcat(ζ_true, CA.ComponentVector(unc = ϕunc))
ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc = ϕunc))
P=θP_true,
Ms=θMs_true)
transPMs = elementwise(exp) # all parameters on LogNormal scale
ζ_true = inverse(transPMs)(θ_true)
ϕ_true = vcat(ζ_true, CA.ComponentVector(unc=ϕunc))
ϕ_cpu = vcat(ζ_true .+ 0.01, CA.ComponentVector(unc=ϕunc))

interpreters = (; pmu = ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs)
interpreters = (; pmu=ComponentArrayInterpreter(ϕ_true)) #, M=int_θM, PMs=int_θPMs)

n_MC = 3
@testset "sample_ζ_norm0 cpu" begin
Expand Down
Loading