diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 8691365..d443ee8 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -28,35 +28,34 @@ gdev = :use_gpu ∈ scenario ? gpu_device() : identity cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity #------ setup synthetic data and training data loader -(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc +prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario); +(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario); -#n_site = get_hybridproblem_n_site(DoubleMM.DoubleMMCase(); scenario) +n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario) ζP_true, ζMs_true = log.(θP_true), log.(θMs_true) i_sites = 1:n_site -xM_cpu = xM; -xM = xM_cpu |> gdev; -get_train_loader = (; n_batch, kwargs...) -> MLUtils.DataLoader( +n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario) +train_dataloader = MLUtils.DataLoader( (xM, xP, y_o, y_unc, 1:n_site); batchsize = n_batch, partial = false) σ_o = exp.(y_unc[:, 1] / 2) - # assign the train_loader, otherwise it eatch time creates another version of synthetic data -prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader) +prob0 = HVI.update(prob0_; train_dataloader); #tmp = HVI.get_hybridproblem_ϕunc(prob0; scenario) #------- pointwise hybrid model fit -solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01), n_batch = 30) +solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01)) #solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 30) #solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 10) #solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200) -n_batches_in_epoch = n_site ÷ solver_point.n_batch +n_batches_in_epoch = n_site ÷ n_batch n_epoch = 80 (; ϕ, resopt, probo) = solve(prob0, solver_point; scenario, rng, 
callback = callback_loss(n_batches_in_epoch * 10), maxiters = n_batches_in_epoch * n_epoch); # update the problem with optimized parameters prob0o = probo; -y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario); +y_pred_global, y_pred, θMs = gf(prob0o, scenario); plt = scatterplot(θMs_true[1, :], θMs[1, :]); lineplot!(plt, 0, 1) scatterplot(θMs_true[2, :], θMs[2, :]) @@ -149,10 +148,10 @@ probh = prob0o # start from point optimized to infer uncertainty #probh = prob1o # start from point optimized to infer uncertainty #probh = prob0 # start from no information solver_post = HybridPosteriorSolver(; - alg = OptimizationOptimisers.Adam(0.01), n_batch = min(50, n_site), n_MC = 3) + alg = OptimizationOptimisers.Adam(0.01), n_MC = 3) #solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200) -n_batches_in_epoch = n_site ÷ solver_post.n_batch -n_epoch = 80 +n_batches_in_epoch = n_site ÷ n_batch +n_epoch = 40 (; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario, rng, callback = callback_loss(n_batches_in_epoch * 5), maxiters = n_batches_in_epoch * n_epoch, @@ -213,6 +212,7 @@ end n_sample_pred = 400 (; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred); (θ2_indep, y2_indep) = (θ, y) + #(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed end () -> begin # otpimize using LUX @@ -246,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :]) # test predicting correct obs-uncertainty of predictive posterior n_sample_pred = 400 -(; θ, y, entropy_ζ) = predict_gf(rng, prob2o, xM, xP; scenario, n_sample_pred); +(; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred); (θ2, y2) = (θ, y) size(y) # n_obs x n_site, n_sample_pred size(θ) # n_θP + n_site * n_θM x n_sample @@ -506,12 +506,13 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread using JLD2 fname = "intermediate/doubleMM_chain_zeta_$(last(scenario)).jld2" jldsave(fname, false, IOStream; chain) - 
chain = load(fname, "chain"; iotype = IOStream) + chain = load(fname, "chain"; iotype = IOStream); end #ζi = first(eachrow(Array(chain))) +f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true) ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain))); -(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen); +(; θ, y) = HVI.predict_ζf(ζs, f_allsites, xP, trans_PMs_gen, intm_PMs_gen); (ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y); diff --git a/dev/negLogDensity.pdf b/dev/negLogDensity.pdf deleted file mode 100644 index dec1226..0000000 Binary files a/dev/negLogDensity.pdf and /dev/null differ diff --git a/dev/r1_density.pdf b/dev/r1_density.pdf deleted file mode 100644 index a05680a..0000000 Binary files a/dev/r1_density.pdf and /dev/null differ diff --git a/dev/ys_density.pdf b/dev/ys_density.pdf deleted file mode 100644 index 4ca94db..0000000 Binary files a/dev/ys_density.pdf and /dev/null differ diff --git a/src/AbstractHybridProblem.jl b/src/AbstractHybridProblem.jl index 67b7e95..71bdff3 100644 --- a/src/AbstractHybridProblem.jl +++ b/src/AbstractHybridProblem.jl @@ -12,7 +12,7 @@ For a specific prob, provide functions that specify details - `get_hybridproblem_train_dataloader` (may use `construct_dataloader_from_synthetic`) - `get_hybridproblem_priors` - `get_hybridproblem_n_covar` -- `get_hybridproblem_n_site` +- `get_hybridproblem_n_site_and_batch` optionally - `gen_hybridproblem_synthetic` - `get_hybridproblem_float_type` (defaults to `eltype(θM)`) @@ -125,11 +125,11 @@ function get_hybridproblem_pbmpar_covars(::AbstractHybridProblem; scenario) end """ - get_hybridproblem_n_site(::AbstractHybridProblem; scenario) + get_hybridproblem_n_site_and_batch(::AbstractHybridProblem; scenario) Provide the number of sites. 
""" -function get_hybridproblem_n_site end +function get_hybridproblem_n_site_and_batch end """ @@ -172,15 +172,10 @@ function get_hybridproblem_train_dataloader end scenario = (), n_batch) Construct a dataloader based on `gen_hybridproblem_synthetic`. -gdev is applied to xM. -If :f_on_gpu is in scenario tuple, gdev is also applied to `xP`, `y_o`, and `y_unc`, -to put the entire data to gpu. -Alternatively, gdev could be applied to the dataloader, then for each -iteration the subset of data is separately transferred to gpu. """ function construct_dataloader_from_synthetic(rng::AbstractRNG, prob::AbstractHybridProblem; scenario = (), n_batch, - gdev = :use_gpu ∈ scenario ? gpu_device() : identity, + #gdev = :use_gpu ∈ scenario ? gpu_device() : identity, ) (; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario) n_site = size(xM,2) @@ -188,14 +183,40 @@ function construct_dataloader_from_synthetic(rng::AbstractRNG, prob::AbstractHyb @assert size(y_o,2) == n_site @assert size(y_unc,2) == n_site i_sites = 1:n_site - xM_dev = gdev(xM) - xP_dev, y_o_dev, y_unc_dev = :f_on_gpu ∈ scenario ? - (gdev(xP), gdev(y_o), gdev(y_unc)) : (xP, y_o, y_unc) - train_loader = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites); + train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites); batchsize = n_batch, partial = false) return (train_loader) end + +""" + gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader, + scenario = (), + gdev = gpu_device(), + gdev_M = :use_gpu ∈ scenario ? gdev : identity, + gdev_P = :f_on_gpu ∈ scenario ? gdev : identity, + batchsize = dataloader.batchsize, + partial = dataloader.partial + ) + +Put relevant parts of the DataLoader to gpu, depending on scenario. +""" +function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader; + scenario = (), + gdev = gpu_device(), + gdev_M = :use_gpu ∈ scenario ? gdev : identity, + gdev_P = :f_on_gpu ∈ scenario ? 
gdev : identity, + batchsize = dataloader.batchsize, + partial = dataloader.partial + ) + xM, xP, y_o, y_unc, i_sites = dataloader.data + xM_dev = gdev_M(xM) + xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc)) + train_loader_dev = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites); + batchsize, partial) + return(train_loader_dev) +end + # function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ()) # rng::AbstractRNG = Random.default_rng() # get_hybridproblem_train_dataloader(rng, prob; scenario) diff --git a/src/ComponentArrayInterpreter.jl b/src/ComponentArrayInterpreter.jl index 46cd5f2..fe5bb2c 100644 --- a/src/ComponentArrayInterpreter.jl +++ b/src/ComponentArrayInterpreter.jl @@ -5,9 +5,12 @@ Interface for Type that implements - `as_ca(::AbstractArray, interpreter) -> ComponentArray` +- `ComponentArrays.getaxes(interpreter)` - `Base.length(interpreter) -> Int` When called on a vector, forwards to `as_ca`. + +There is a default implementation for Base.length based on ComponentArrays.getaxes. """ abstract type AbstractComponentArrayInterpreter end @@ -18,6 +21,11 @@ Returns a ComponentArray with underlying data `v`. 
""" function as_ca end +function Base.length(cai::AbstractComponentArrayInterpreter) + prod(_axis_length.(CA.getaxes(cai))) +end + + (interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray) = as_ca(v, interpreter) """ @@ -36,9 +44,13 @@ function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {A CA.ComponentArray(vr, AX) end -function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX} - #sum(length, typeof(AX).parameters[1]) - prod(_axis_length.(AX)) +# function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX} +# #sum(length, typeof(AX).parameters[1]) +# prod(_axis_length.(AX)) +# end + +function CA.getaxes(int::StaticComponentArrayInterpreter{AX}) where {AX} + AX end get_concrete(cai::StaticComponentArrayInterpreter) = cai @@ -63,10 +75,11 @@ function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter) CA.ComponentArray(vr, cai.axes) end -function Base.length(cai::ComponentArrayInterpreter) - prod(_axis_length.(cai.axes)) +function CA.getaxes(cai::ComponentArrayInterpreter) + cai.axes end + get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{cai.axes}() @@ -120,6 +133,10 @@ function ComponentArrayInterpreter( ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where N ComponentArrayInterpreter(CA.getaxes(ca), n_dims) end +function ComponentArrayInterpreter( + cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where N + ComponentArrayInterpreter(CA.getaxes(cai), n_dims) +end function ComponentArrayInterpreter( axes::NTuple{M, <:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N} axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...) 
@@ -131,12 +148,17 @@ function ComponentArrayInterpreter( n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where N ComponentArrayInterpreter(n_dims, CA.getaxes(ca)) end +function ComponentArrayInterpreter( + n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where N + ComponentArrayInterpreter(n_dims, CA.getaxes(cai)) +end function ComponentArrayInterpreter( n_dims::NTuple{N,<:Integer}, axes::NTuple{M, <:CA.AbstractAxis}) where {N,M} axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...) ComponentArrayInterpreter(axes_ext) end + # ambuiguity with two empty Tuples (edge prob that does not make sense) # Empty ComponentVector with no other array dimensions -> empty componentVector function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{}) @@ -156,6 +178,8 @@ _axis_length(::CA.FlatAxis) = 0 _axis_length(::CA.UnitRange) = 0 """ + flatten1(cv::CA.ComponentVector) + Removes the highest level of keys. Keeps the reference to the underlying data, but changes the axis. If first-level vector has no sub-names, an error (Aguement Error tuple must be non-empty) @@ -174,3 +198,16 @@ function flatten1(cv::CA.ComponentVector) CA.ComponentVector(cv, first(CA.getaxes(cv_new))) end end + + +""" + get_positions(cai::AbstractComponentArrayInterpreter) + +Create a NamedTuple of integer indices for each component. +Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector. +""" +function get_positions(cai::AbstractComponentArrayInterpreter) + @assert length(CA.getaxes(cai)) == 1 + cv = cai(1:length(cai)) + (; (k => cv[k] for k in keys(cv))... 
) +end diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 18d1f4e..a5ef56a 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -6,10 +6,10 @@ const θall = vcat(θP, θM) const θP_nor0 = θP[(:K2,)] -const transP = elementwise(exp) -const transM = elementwise(exp) +# const transP = elementwise(exp) +# const transM = elementwise(exp) -const transMS = Stacked(elementwise(identity), elementwise(exp)) +# const transMS = Stacked(elementwise(identity), elementwise(exp)) const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM))) @@ -19,34 +19,61 @@ function f_doubleMM(θ::AbstractVector, x, intθ) θc = intθ(θ) #using ComponentArrays: ComponentArrays as CA #r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] # does not work on Zygote+GPU - r0 = θc[:r0] - r1 = θc[:r1] - K1 = θc[:K1] - K2 = θc[:K2] + (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par + # vector will be repeated when broadcasted by a matrix + CA.getdata(θc[par]) + end + # r0 = θc[:r0] + # r1 = θc[:r1] + # K1 = θc[:K1] + # K2 = θc[:K2] y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) end return (y) end -function f_doubleMM(θ::AbstractMatrix, x::NTuple{N, AbstractMatrix}, intθ) where N +function f_doubleMM(θ::AbstractMatrix, x::NamedTuple, intθ::HVI.AbstractComponentArrayInterpreter) # provide θ for n_row sites # provide x.S1 as Matrix n_site x n_obs # extract parameters not depending on order, i.e whether they are in θP or θM θc = intθ(θ) - #using ComponentArrays: ComponentArrays as CA - #r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] # does not work on Zygote+GPU @assert size(x.S1,1) == size(θ,1) # same number of sites @assert size(x.S1) == size(x.S2) # same number of observations #@assert length(x.s2 == n_obs) - r0 = θc[:,:r0] # vector will be repeated when broadcasted by a matrix - r1 = θc[:,:r1] - K1 = θc[:,:K1] - K2 = θc[:,:K2] + # problems on AD on GPU with indexing CA may be related to printing result, use ";" + (r0, r1, K1, K2) = 
map((:r0, :r1, :K1, :K2)) do par + # vector will be repeated when broadcasted by a matrix + CA.getdata(θc[:,par]) + end + # r0 = CA.getdata(θc[:,:r0]) # vector will be repeated when broadcasted by a matrix + # r1 = CA.getdata(θc[:,:r1]) + # K1 = CA.getdata(θc[:,:K1]) + # K2 = CA.getdata(θc[:,:K2]) y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) - return (y) + return (y) end - +# function f_doubleMM(θ::AbstractMatrix, x::NamedTuple, θpos::NamedTuple) +# # provide θ for n_row sites +# # provide x.S1 as Matrix n_site x n_obs +# # extract parameters not depending on order, i.e whether they are in θP or θM +# @assert size(x.S1,1) == size(θ,1) # same number of sites +# @assert size(x.S1) == size(x.S2) # same number of observations +# (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par +# # vector will be repeated when broadcasted by a matrix +# CA.getdata(θ[:,θpos[par]]) +# end +# # r0 = CA.getdata(θ[:,θpos.r0]) # vector will be repeated when broadcasted by a matrix +# # r1 = CA.getdata(θ[:,θpos.r1]) +# # K1 = CA.getdata(θ[:,θpos.K1]) +# # K2 = CA.getdata(θ[:,θpos.K2]) +# #y = r0 .+ r1 +# #y = x.S1 + x.S2 +# #y = (K1 .+ x.S1) +# #y = r1 .* x.S1 ./ (K1 .+ x.S1) +# y = r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2) +# return (y) +# end function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ()) if (:omit_r0 ∈ scenario) @@ -87,15 +114,17 @@ end function HVI.get_hybridproblem_transforms(prob::DoubleMMCase; scenario::NTuple = ()) + _θP, _θM = get_hybridproblem_par_templates(prob; scenario) if (:stackedMS ∈ scenario) - return ((; transP, transM = transMS)) + return (; transP = Stacked((HVI.Exp(),),(1:length(_θP),)), + transM = Stacked((identity,HVI.Exp(),),(1:1, 2:length(_θM),))) elseif (:transIdent ∈ scenario) # identity transformations, should AD on GPU - _θP, _θM = get_hybridproblem_par_templates(prob; scenario) return (; transP = Stacked((identity,),(1:length(_θP),)), transM = Stacked((identity,),(1:length(_θM),))) end - (; 
transP, transM) + (; transP = Stacked((HVI.Exp(),),(1:length(_θP),)), + transM = Stacked((HVI.Exp(),),(1:length(_θM),))) end # function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario = ()) @@ -109,15 +138,49 @@ end # (; n_covar, n_batch, n_θM, n_θP) # end +# function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = (), +# gdev = :f_on_gpu ∈ scenario ? gpu_device() : identity, +# ) +# #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers +# par_templates = get_hybridproblem_par_templates(prob; scenario) +# intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) +# let θFix = gdev(θFix), intθ = get_concrete(intθ) +# function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) +# pred_sites = applyf(f_doubleMM, θMs, θP, θFix, xP, intθ) +# pred_global = eltype(pred_sites)[] +# return pred_global, pred_sites +# end +# end +# end + function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = (), + use_all_sites = false, gdev = :f_on_gpu ∈ scenario ? gpu_device() : identity, ) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + n_site_batch = use_all_sites ? 
n_site : n_batch #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers par_templates = get_hybridproblem_par_templates(prob; scenario) - intθ, θFix = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) - let θFix = gdev(θFix), intθ = get_concrete(intθ) - function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x) - pred_sites = applyf(f_doubleMM, θMs, θP, θFix, x, intθ) + intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) + θFix = repeat(θFix1', n_site_batch) + intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1)) + isP = repeat(axes(par_templates.θP,1)', n_site_batch) + let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP=isP, n_site_batch=n_site_batch + function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) + @assert length(xP) == n_site_batch + @assert size(θMs,2) == n_site_batch + # convert vector of tuples to tuple of matricesByRows + # need to supply xP as vectorOfTuples to work with DataLoader + # k = first(keys(xP[1])) + xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k + #stack(map(r -> r[k], xP))' + stack(map(r -> r[k], xP); dims = 1) + end)...) + #xPM = map(transpose, xPM1) + # make sure the same order of columns as in intθ + θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? 
θFix_dev : θFix + θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs)', θFixd) + pred_sites = f_doubleMM(θ, xPM, intθ)' pred_global = eltype(pred_sites)[] return pred_global, pred_sites end @@ -130,6 +193,7 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = () end end + function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ()) neg_logden_indep_normal end @@ -146,25 +210,28 @@ const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0] # const xP_S2 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0] HVI.get_hybridproblem_n_covar(prob::DoubleMMCase; scenario) = 5 -function HVI.get_hybridproblem_n_site(prob::DoubleMMCase; scenario) +function HVI.get_hybridproblem_n_site_and_batch(prob::DoubleMMCase; scenario) + n_batch = 20 + n_site = 800 if (:few_sites ∈ scenario) - return(100) + n_site = 100 elseif (:sites20 ∈ scenario) - return(20) + n_site = 20 end - 800 + (n_site, n_batch) end function HVI.get_hybridproblem_train_dataloader(prob::DoubleMMCase; scenario = (), - n_batch, rng::AbstractRNG = StableRNG(111), kwargs... + rng::AbstractRNG = StableRNG(111), kwargs... ) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) construct_dataloader_from_synthetic(rng, prob; scenario, n_batch, kwargs...) 
end function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; scenario = ()) n_covar_pc = 2 - n_site = get_hybridproblem_n_site(prob; scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) n_covar = get_hybridproblem_n_covar(prob; scenario) n_θM = length(θM) FloatType = get_hybridproblem_float_type(prob; scenario) @@ -174,7 +241,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,)) # normalize to be distributed around the prescribed true values θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1))) - f = get_hybridproblem_PBmodel(prob; scenario, gdev=identity) + f = get_hybridproblem_PBmodel(prob; scenario, gdev=identity, use_all_sites = true) xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site) θP = par_templates.θP y_global_true, y_true = f(θP, θMs_true, xP) @@ -186,7 +253,6 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o (; xM, - n_site, θP_true = θP, θMs_true, xP, diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index 625c3e3..5be6a91 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -1,70 +1,71 @@ struct HybridProblem <: AbstractHybridProblem θP::Any θM::Any - f::Any + f_batch::Any + f_allsites::Any g::Any ϕg::Any ϕunc::Any priors::Any py::Any - transP::Any transM::Any + transP::Any cor_ends::Any # = (P=(1,),M=(1,)) - get_train_loader::Any + train_dataloader::Any n_covar::Int n_site::Int + n_batch::Int pbm_covars::NTuple - # inner constructor to constrain the types + #inner constructor to constrain the types function HybridProblem( θP::CA.ComponentVector, θM::CA.ComponentVector, g::AbstractModelApplicator, ϕg::AbstractVector, ϕunc::CA.ComponentVector, - f::Function, + f_batch::Function, + f_allsites::Function, priors::AbstractDict, py::Function, transM::Union{Function, Bijectors.Transform}, 
transP::Union{Function, Bijectors.Transform}, # return a function that constructs the trainloader based on n_batch - get_train_loader::Function, + train_dataloader::MLUtils.DataLoader, n_covar::Int, n_site::Int, + n_batch::Int, cor_ends::NamedTuple = (P = [length(θP)], M = [length(θM)]), - pbm_covars::NTuple{N,Symbol} = (), + pbm_covars::NTuple{N,Symbol} = () ) where N new( - θP, θM, f, g, ϕg, ϕunc, priors, py, transM, transP, cor_ends, get_train_loader, - n_covar, n_site, pbm_covars) + θP, θM, f_batch, f_allsites, g, ϕg, ϕunc, priors, py, transM, transP, cor_ends, + train_dataloader, n_covar, n_site, n_batch, pbm_covars) end end function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, # note no ϕg argument and g_chain unconstrained - g_chain, f::Function, + g_chain, f_batch::Function, args...; rng = Random.default_rng(), kwargs...) # dispatches on type of g_chain g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM)) - HybridProblem(θP, θM, g, ϕg, f, args...; kwargs...) + HybridProblem(θP, θM, g, ϕg, f_batch, args...; kwargs...) end function HybridProblem(prob::AbstractHybridProblem; scenario = ()) (; θP, θM) = get_hybridproblem_par_templates(prob; scenario) g, ϕg = get_hybridproblem_MLapplicator(prob; scenario) ϕunc = get_hybridproblem_ϕunc(prob; scenario) - f = get_hybridproblem_PBmodel(prob; scenario) + f_batch = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) + f_allsites = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) py = get_hybridproblem_neg_logden_obs(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) - get_train_loader = let prob = prob, scenario = scenario - function inner_get_train_loader(;kwargs...) - get_hybridproblem_train_dataloader(prob; scenario, kwargs...) 
- end - end + train_dataloader = get_hybridproblem_train_dataloader(prob; scenario) cor_ends = get_hybridproblem_cor_ends(prob; scenario) pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) priors = get_hybridproblem_priors(prob; scenario) n_covar = get_hybridproblem_n_covar(prob; scenario) - n_site = get_hybridproblem_n_site(prob; scenario) - HybridProblem(θP, θM, g, ϕg, ϕunc, f, priors, py, transP, transM, get_train_loader, - n_covar, n_site, cor_ends, pbm_covars) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + HybridProblem(θP, θM, g, ϕg, ϕunc, f_batch, f_allsites, priors, py, transM, transP, train_dataloader, + n_covar, n_site, n_batch, cor_ends, pbm_covars) end function update(prob::HybridProblem; @@ -73,19 +74,23 @@ function update(prob::HybridProblem; g::AbstractModelApplicator = prob.g, ϕg::AbstractVector = prob.ϕg, ϕunc::CA.ComponentVector = prob.ϕunc, - f::Function = prob.f, + f_batch::Function = prob.f_batch, + f_allsites::Function = prob.f_allsites, priors::AbstractDict = prob.priors, py::Function = prob.py, - transM::Union{Function, Bijectors.Transform} = prob.transM, - transP::Union{Function, Bijectors.Transform} = prob.transP, + # transM::Union{Function, Bijectors.Transform} = prob.transM, + # transP::Union{Function, Bijectors.Transform} = prob.transP, + transM = prob.transM, + transP = prob.transP, cor_ends::NamedTuple = prob.cor_ends, pbm_covars::NTuple{N,Symbol} = prob.pbm_covars, - get_train_loader::Function = prob.get_train_loader, + train_dataloader::MLUtils.DataLoader = prob.train_dataloader, n_covar::Integer = prob.n_covar, - n_site::Integer = prob.n_site + n_site::Integer = prob.n_site, + n_batch::Integer = prob.n_batch, ) where N - HybridProblem(θP, θM, g, ϕg, ϕunc, f, priors, py, transP, transM, get_train_loader, - n_covar, n_site, cor_ends, pbm_covars) + HybridProblem(θP, θM, g, ϕg, ϕunc, f_batch, f_allsites, priors, py, transM, transP, + train_dataloader, n_covar, n_site, n_batch, cor_ends, pbm_covars) 
end function get_hybridproblem_par_templates(prob::HybridProblem; scenario::NTuple = ()) @@ -110,16 +115,16 @@ end # (; n_covar=prob.n_covar, n_batch=prob.n_batch, n_θM, n_θP) # end -function get_hybridproblem_PBmodel(prob::HybridProblem; scenario::NTuple = ()) - prob.f +function get_hybridproblem_PBmodel(prob::HybridProblem; scenario::NTuple = (), use_all_sites=false) + use_all_sites ? prob.f_allsites : prob.f_batch end function get_hybridproblem_MLapplicator(prob::HybridProblem; scenario::NTuple = ()) prob.g, prob.ϕg end -function get_hybridproblem_train_dataloader(prob::HybridProblem; kwargs...) - return prob.get_train_loader(;kwargs...) +function get_hybridproblem_train_dataloader(prob::HybridProblem; scenario = ()) + prob.train_dataloader end function get_hybridproblem_cor_ends(prob::HybridProblem; scenario = ()) @@ -131,8 +136,8 @@ end function get_hybridproblem_n_covar(prob::HybridProblem; scenario = ()) prob.n_covar end -function get_hybridproblem_n_site(prob::HybridProblem; scenario = ()) - prob.n_site +function get_hybridproblem_n_site_and_batch(prob::HybridProblem; scenario = ()) + prob.n_site, prob.n_batch end function get_hybridproblem_priors(prob::HybridProblem; scenario = ()) diff --git a/src/HybridSolver.jl b/src/HybridSolver.jl index 6e00147..60eb1dc 100644 --- a/src/HybridSolver.jl +++ b/src/HybridSolver.jl @@ -2,11 +2,9 @@ abstract type AbstractHybridSolver end struct HybridPointSolver{A} <: AbstractHybridSolver alg::A - n_batch::Int end -HybridPointSolver(; alg, n_batch = 10) = HybridPointSolver(alg, n_batch) -#HybridPointSolver(; alg = Adam(0.02), n_batch = 10) = HybridPointSolver(alg,n_batch) +HybridPointSolver(; alg) = HybridPointSolver(alg) function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolver; scenario, rng = Random.default_rng(), @@ -21,32 +19,32 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve ϕg = 1:length(ϕg0), ϕP = par_templates.θP)) #ϕ0_cpu = vcat(ϕg0, par_templates.θP .* 
FT(0.9)) # slightly disturb θP_true ϕ0_cpu = vcat(ϕg0, apply_preserve_axes(inverse(transP),par_templates.θP)) + train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ0_cpu) g_dev = gdev(g) + train_loader_dev = gdev_hybridproblem_dataloader(train_loader; scenario, gdev) else ϕ0_dev = ϕ0_cpu g_dev = g + train_loader_dev = train_loader end - train_loader = get_hybridproblem_train_dataloader( - prob; scenario, n_batch = solver.n_batch) - f = get_hybridproblem_PBmodel(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) y_global_o = FT[] # TODO pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) #intP = ComponentArrayInterpreter(par_templates.θP) loss_gf = get_loss_gf(g_dev, transM, transP, f, y_global_o, intϕ; cdev, pbm_covars) # call loss function once - l1 = loss_gf(ϕ0_dev, first(train_loader)...)[1] + l1 = loss_gf(ϕ0_dev, first(train_loader_dev)...)[1] # and gradient - # xMg, xP, y_o, y_unc = first(train_loader) + # xMg, xP, y_o, y_unc = first(train_loader_dev) # gr1 = Zygote.gradient( # p -> loss_gf(p, xMg, xP, y_o, y_unc)[1], # ϕ0_dev) - # data1 = first(train_loader) - # Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev, data1...)[1], ϕ0_dev) + # Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev, data1...)[1], ϕ0_dev) optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1], Optimization.AutoZygote()) - optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader) + optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader_dev) res = Optimization.solve(optprob, solver.alg; kwargs...) 
ϕ = intϕ(res.u) θP = cpu_ca(apply_preserve_axes(transP, cpu_ca(ϕ).ϕP)) @@ -56,19 +54,17 @@ end struct HybridPosteriorSolver{A} <: AbstractHybridSolver alg::A - n_batch::Int n_MC::Int n_MC_cap::Int end -function HybridPosteriorSolver(; alg, n_batch = 10, n_MC = 12, n_MC_cap = n_MC) - HybridPosteriorSolver(alg, n_batch, n_MC, n_MC_cap) +function HybridPosteriorSolver(; alg, n_MC = 12, n_MC_cap = n_MC) + HybridPosteriorSolver(alg, n_MC, n_MC_cap) end function update(solver::HybridPosteriorSolver; alg = solver.alg, - n_batch = solver.n_batch, n_MC = solver.n_MC, n_MC_cap = n_MC) - HybridPosteriorSolver(alg, n_batch, n_MC, n_MC_cap) + HybridPosteriorSolver(alg, n_MC, n_MC_cap) end function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorSolver; @@ -84,34 +80,37 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS ϕunc0 = get_hybridproblem_ϕunc(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, solver.n_batch; transP, transM, ϕunc0) + θP, θM, cor_ends, ϕg0, n_batch; transP, transM, ϕunc0) + train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ) g_dev = gdev(g) # zygote fails if gdev is a CPUDevice, although should be non-op + train_loader_dev = gdev_hybridproblem_dataloader(train_loader; scenario, gdev) else ϕ0_dev = ϕ g_dev = g + train_loader_dev = train_loader end - train_loader = get_hybridproblem_train_dataloader(prob; scenario, solver.n_batch) - f = get_hybridproblem_PBmodel(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) py = get_hybridproblem_neg_logden_obs(prob; scenario) priors_θ_mean = construct_priors_θ_mean( prob, ϕ0_dev.ϕg, 
keys(θM), θP, θmean_quant, g_dev, transM, transP; - scenario, get_ca_int_PMs, cdev, pbm_covars) + scenario, get_ca_int_PMs, gdev, cdev, pbm_covars) y_global_o = Float32[] # TODO loss_elbo = get_loss_elbo( g_dev, transPMs_batch, f, py, y_global_o, interpreters; solver.n_MC, solver.n_MC_cap, cor_ends, priors_θ_mean, cdev, pbm_covars, θP) # test loss function once - l0 = loss_elbo(ϕ0_dev, rng, first(train_loader)...) + l0 = loss_elbo(ϕ0_dev, rng, first(train_loader_dev)...) optf = Optimization.OptimizationFunction((ϕ, data) -> loss_elbo(ϕ, rng, data...)[1], Optimization.AutoZygote()) - optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader) + optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader_dev) res = Optimization.solve(optprob, solver.alg; kwargs...) ϕc = interpreters.μP_ϕg_unc(res.u) θP = cpu_ca(apply_preserve_axes(transP, ϕc.μP)) - probo = update(prob; ϕg = cpu_ca(ϕ).ϕg, θP = θP, ϕunc = cpu_ca(ϕ).unc); + probo = update(prob; ϕg = cpu_ca(ϕc).ϕg, θP = θP, ϕunc = cpu_ca(ϕc).unc); (; ϕ = ϕc, θP, resopt = res, interpreters, probo) end @@ -157,14 +156,15 @@ end """ Compute the components of the elbo for given initial conditions of the problems -for the first batch of the trainloader, whose `n_batch` defaults to all sites. +for the first batch of the trainloader. """ function compute_elbo_components( prob::AbstractHybridProblem, solver::HybridPosteriorSolver; scenario, rng = Random.default_rng(), gdev = gpu_device(), θmean_quant = 0.0, - n_batch = get_hybridproblem_n_site(prob; scenario), + use_all_sites = false, kwargs...) 
+ n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) par_templates = get_hybridproblem_par_templates(prob; scenario) (; θP, θM) = par_templates cor_ends = get_hybridproblem_cor_ends(prob; scenario) @@ -172,21 +172,24 @@ function compute_elbo_components( ϕunc0 = get_hybridproblem_ϕunc(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, n_batch; transP, transM, ϕunc0) + θP, θM, cor_ends, ϕg0, n_batch; transP, transM, ϕunc0) + train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdev isa MLDataDevices.AbstractGPUDevice ϕ0_dev = gdev(ϕ) g_dev = gdev(g) # zygote fails if gdev is a CPUDevice, although should be non-op + train_loader_dev = gdev_hybridproblem_dataloader(train_loader; scenario, gdev) else ϕ0_dev = ϕ g_dev = g + train_loader_dev = train_loader end - train_loader = get_hybridproblem_train_dataloader(prob; scenario, n_batch) - f = get_hybridproblem_PBmodel(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites) py = get_hybridproblem_neg_logden_obs(prob; scenario) priors_θ_mean = construct_priors_θ_mean( prob, ϕ0_dev.ϕg, keys(θM), θP, θmean_quant, g_dev, transM; - scenario, get_ca_int_PMs) - xM, xP, y_o, y_unc, i_sites = first(train_loader) + scenario, get_ca_int_PMs, gdev, cdev, pbm_covars) + # TODO replace train_loader.data by proper function that pulls all the data + xM, xP, y_o, y_unc, i_sites = use_all_sites ? train_loader_dev.data : first(train_loader_dev) neg_elbo_gtf_components( rng, ϕ0_dev, g_dev, transPMs_batch, f, py, xM, xP, y_o, y_unc, i_sites, interpreters; solver.n_MC, solver.n_MC_cap, cor_ends, priors_θ_mean) @@ -197,17 +200,21 @@ In order to let mean of θ stay close to initial point parameter estimates construct a prior on mean θ to a Normal around initial prediction. 
""" function construct_priors_θ_mean(prob, ϕg, keysθM, θP, θmean_quant, g_dev, transM, transP; - scenario, get_ca_int_PMs, cdev, pbm_covars) + scenario, get_ca_int_PMs, gdev, cdev, pbm_covars) iszero(θmean_quant) ? [] : begin - n_site = get_hybridproblem_n_site(prob; scenario) - all_loader = get_hybridproblem_train_dataloader(prob; scenario, n_batch = n_site) - xM_all = first(all_loader)[1] - #Main.@infiltrate_main + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + # all_loader = MLUtils.DataLoader( + # get_hybridproblem_train_dataloader(prob; scenario).data, batchsize = n_site) + # xM_all = first(all_loader)[1] + is_gpu = :use_gpu ∈ scenario + xM_all_cpu = get_hybridproblem_train_dataloader(prob; scenario).data[1] + xM_all = is_gpu ? gdev(xM_all_cpu) : xM_all_cpu ζP = apply_preserve_axes(inverse(transP), θP) pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars) xMP_all = _append_each_covars(xM_all, CA.getdata(ζP), pbm_covar_indices) - θMs = gtrans(g_dev, transM, xMP_all, CA.getdata(ϕg); cdev) + #Main.@infiltrate_main + θMs = gtrans(g_dev, transM, xMP_all, CA.getdata(ϕg); cdev = cpu_device()) priors_dict = get_hybridproblem_priors(prob; scenario) priorsP = [priors_dict[k] for k in keys(θP)] priors_θP_mean = map(priorsP, θP) do priorsP, θPi diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index efb2f3f..095e19d 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -20,7 +20,10 @@ using Distributions, DistributionFits using StaticArrays: StaticArrays as SA using Functors -export ComponentArrayInterpreter, flatten1, get_concrete +#export Exp +include("bijectors_utils.jl") + +export ComponentArrayInterpreter, flatten1, get_concrete, get_positions include("ComponentArrayInterpreter.jl") export AbstractModelApplicator, construct_ChainsApplicator @@ -38,13 +41,14 @@ export AbstractHybridProblem, get_hybridproblem_MLapplicator, get_hybridproblem_ 
get_hybridproblem_train_dataloader, get_hybridproblem_neg_logden_obs, get_hybridproblem_n_covar, - get_hybridproblem_n_site, + get_hybridproblem_n_site_and_batch, get_hybridproblem_cor_ends, get_hybridproblem_priors, get_hybridproblem_pbmpar_covars, #update, gen_cov_pred, construct_dataloader_from_synthetic, + gdev_hybridproblem_dataloader, setup_PBMpar_interpreter include("AbstractHybridProblem.jl") diff --git a/src/bijectors_utils.jl b/src/bijectors_utils.jl new file mode 100644 index 0000000..3cc70ee --- /dev/null +++ b/src/bijectors_utils.jl @@ -0,0 +1,21 @@ +struct Exp <: Bijector +end + +#Functors.@functor Exp +Bijectors.transform(b::Exp, x) = exp.(x) # note the broadcast +Bijectors.transform(ib::Inverse{<:Exp}, y) = log.(y) + +# `logabsdetjac` +Bijectors.logabsdetjac(b::Exp, x) = sum(x) + +`with_logabsdet_jacobian` +function Bijectors.with_logabsdet_jacobian(b::Exp, x) + return exp.(x), sum(x) +end +# function Bijectors.with_logabsdet_jacobian(ib::Inverse{<:Exp}, y) +# x = transform(ib, y) +# return x, -logabsdetjac(inverse(ib), x) +# end + + +Bijectors.is_monotonically_increasing(::Exp) = true diff --git a/src/elbo.jl b/src/elbo.jl index bc9b3ad..dfa69ea 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -63,6 +63,8 @@ function neg_elbo_gtf_components(rng, ϕ::AbstractVector, g, transPMs, f, py, #nLmean_θP = map((d, θi) -> -logpdf(d, θi), CA.getdata(priors_θ_mean.P), mean_θP) #workaround for Zygote failing on `priors_θ_mean.P` iθ = CA.ComponentArray(1:length(priors_θ_mean), CA.getaxes(priors_θ_mean)) + # need to apply different dist to each entry in θP and mean_θMs -> @allowscalar + # but does not work with Zygote nLmean_θP = map((d, θi) -> -logpdf(d, θi), priors_θ_mean[CA.getdata(iθ.P)], mean_θP) θMss = map(θ -> interpreters.PMs(θ).Ms, θs) |> stack mean_θMs = mean(θMss; dims = (3))[:, :, 1] @@ -131,9 +133,10 @@ Prediction function for hybrid model. Returns an NamedTuple with entries - `y`: Array `(n_obs, n_site, n_sample_pred)` of model predictions. 
""" function predict_gf(rng, prob::AbstractHybridProblem; scenario, kwargs...) - n_batch = get_hybridproblem_n_site(prob; scenario) - data = first(get_hybridproblem_train_dataloader(prob; scenario, n_batch)) - predict_gf(rng, prob, data[1], data[2]; scenario, kwargs...) + dl = get_hybridproblem_train_dataloader(prob; scenario) + dl_dev = gdev_hybridproblem_dataloader(dl; scenario) + xM, xP = dl_dev.data[1:2] + predict_gf(rng, prob, xM, xP; scenario, kwargs...) end function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP; scenario, @@ -141,19 +144,22 @@ function predict_gf(rng, prob::AbstractHybridProblem, xM::AbstractMatrix, xP; gdev = :use_gpu ∈ scenario ? gpu_device() : identity, cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity ) - n_site = length(xP) - @assert size(xM, 2) == n_site + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + is_predict_batch = (n_batch == length(xP)) + n_site_pred = is_predict_batch ? n_batch : n_site + @assert length(xP) == n_site_pred + @assert size(xM, 2) == n_site_pred + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = !is_predict_batch) par_templates = get_hybridproblem_par_templates(prob; scenario) (; θP, θM) = par_templates cor_ends = get_hybridproblem_cor_ends(prob; scenario) g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) ϕunc0 = get_hybridproblem_ϕunc(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) - f = get_hybridproblem_PBmodel(prob; scenario) pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars) (; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params( - θP, θM, cor_ends, ϕg0, n_site; transP, transM, ϕunc0) + θP, θM, cor_ends, ϕg0, n_site_pred; transP, transM, ϕunc0) g_dev, ϕ_dev = gdev(g), gdev(ϕ) predict_gf(rng, g_dev, f, ϕ_dev, xM, xP, interpreters; get_transPMs, get_ca_int_PMs, n_sample_pred, cdev, cor_ends, 
pbm_covar_indices) diff --git a/src/gf.jl b/src/gf.jl index 81e129a..58868b5 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -19,13 +19,24 @@ end """ composition f ∘ transM ∘ g: mechanistic model after machine learning parameter prediction """ -function gf(prob::AbstractHybridProblem, xM, xP, args...; +function gf(prob::AbstractHybridProblem, args...; scenario = (), kwargs...) + train_loader = get_hybridproblem_train_dataloader(prob; scenario) + train_loader_dev = gdev_hybridproblem_dataloader(train_loader; scenario) + xM, xP = train_loader_dev.data[1:2] + gf(prob, xM, xP, args...; kwargs...) +end +function gf(prob::AbstractHybridProblem, xM::AbstractMatrix, xP::AbstractVector, args...; scenario = (), gdev = :use_gpu ∈ scenario ? gpu_device() : identity, cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity, kwargs...) g, ϕg = get_hybridproblem_MLapplicator(prob; scenario) - f = get_hybridproblem_PBmodel(prob; scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + is_predict_batch = (n_batch == length(xP)) + n_site_pred = is_predict_batch ? n_batch : n_site + @assert length(xP) == n_site_pred + @assert size(xM, 2) == n_site_pred + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = !is_predict_batch) (; θP, θM) = get_hybridproblem_par_templates(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) intP = ComponentArrayInterpreter(θP) diff --git a/src/logden_normal.jl b/src/logden_normal.jl index 50a8509..e25fafb 100644 --- a/src/logden_normal.jl +++ b/src/logden_normal.jl @@ -13,7 +13,8 @@ a low uncertainty estimate and means closer to the observations to help an initial fit. The obtained parameters then can be used as starting values for a the proper fit with `σfac=1.0`. 
""" -function neg_logden_indep_normal(obs::AbstractArray, μ::AbstractArray, logσ2::AbstractArray; σfac=1.0) +function neg_logden_indep_normal(obs::AbstractArray, μ::AbstractArray, logσ2::AbstractArray{ET}; + σfac=one(ET)) where ET # log of independent Normal distributions # estimate independent uncertainty of each θM, rather than full covariance #nlogL = sum(σfac .* log.(σs) .+ 1 / 2 .* abs2.((obs .- μ) ./ σs)) diff --git a/test/runtests.jl b/test/runtests.jl index 45f7450..2c87aa2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,8 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml @time begin if GROUP == "All" || GROUP == "Basic" + #@safetestset "test" include("test/test_bijectors_utils.jl") + @time @safetestset "test_bijectors_utils" include("test_bijectors_utils.jl") #@safetestset "test" include("test/test_ComponentArrayInterpreter.jl") @time @safetestset "test_ComponentArrayInterpreter" include("test_ComponentArrayInterpreter.jl") #@safetestset "test" include("test/test_ModelApplicator.jl") diff --git a/test/test_ComponentArrayInterpreter.jl b/test/test_ComponentArrayInterpreter.jl index f392dae..f3b8fda 100644 --- a/test/test_ComponentArrayInterpreter.jl +++ b/test/test_ComponentArrayInterpreter.jl @@ -57,6 +57,8 @@ end; testm(mm) mmc = get_concrete(mm) testm(mmc) + mmi = ComponentArrayInterpreter(mv, (n_col,)) # construct on interpreter itself + testm(mmi) # n_z = 3 mm = ComponentArrayInterpreter(cv, (n_col, n_z)) @@ -67,6 +69,8 @@ end; end testm(mm) testm(get_concrete(mm)) + mmi = ComponentArrayInterpreter(mv, (n_col, n_z)) # construct on interpreter itself + testm(mmi) # n_row = 3 mm = ComponentArrayInterpreter((n_row,), cv) @@ -77,6 +81,8 @@ end; end testm(mm) testm(get_concrete(mm)) + mm = ComponentArrayInterpreter((n_row,), mv) # construct on interpreter itself + testm(mmi) end; @testset "empty ComponentVector" begin diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 36939ea..e9d717f 100644 --- 
a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -1,6 +1,6 @@ using Test using HybridVariationalInference -using HybridVariationalInference: HybridVariationalInference as HVI +using HybridVariationalInference: HybridVariationalInference as CP using StableRNGs using Random using Statistics @@ -46,8 +46,10 @@ construct_problem = (;scenario=(:default,)) -> begin end n_out = length(θM) rng = StableRNG(111) + # n_batch = 10 + n_site, n_batch = get_hybridproblem_n_site_and_batch(CP.DoubleMM.DoubleMMCase(); scenario) # dependency on DeoubleMMCase -> take care of changes in covariates - (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc + (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase()) n_covar = size(xM,1) n_input = (:covarK2 ∈ scenario) ? n_covar +1 : n_covar @@ -62,13 +64,13 @@ construct_problem = (;scenario=(:default,)) -> begin # g, ϕg = construct_SimpleChainsApplicator(g_chain) # py = neg_logden_indep_normal - n_batch = 10 i_sites = 1:n_site - get_train_loader = let xM = xM, xP = xP, y_o = y_o, y_unc = y_unc, i_sites = i_sites - function inner_get_train_loader(; n_batch, kwargs...) - MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites), batchsize=n_batch, partial=false) - end - end + # get_train_loader = let xM = xM, xP = xP, y_o = y_o, y_unc = y_unc, i_sites = i_sites + # function inner_get_train_loader(; n_batch, kwargs...) 
+ # MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites), batchsize=n_batch, partial=false) + # end + # end + train_dataloader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites), batchsize=n_batch, partial=false) θall = vcat(θP, θM) priors_dict = Dict{Symbol, Distribution}(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) priors_dict[:r1] = fit(Normal, θall.r1, qp_uu(3 * θall.r1)) # not transformed to log-scale @@ -79,8 +81,10 @@ construct_problem = (;scenario=(:default,)) -> begin #g_chain_scaled = app ϕunc0 = init_hybrid_ϕunc(cor_ends, zero(FT)) pbm_covars = (:covarK2 ∈ scenario) ? (:K2,) : () - HybridProblem(θP, θM, g_chain_scaled, ϕg0, ϕunc0, f_doubleMM_with_global, priors_dict, py, - transM, transP, get_train_loader, n_covar, n_site, cor_ends, pbm_covars) + HybridProblem(θP, θM, g_chain_scaled, ϕg0, ϕunc0, + f_doubleMM_with_global, f_doubleMM_with_global, priors_dict, py, + transM, transP, train_dataloader, n_covar, n_site, n_batch, + cor_ends, pbm_covars) end test_without_flux = (scenario) -> begin @@ -104,9 +108,10 @@ test_without_flux = (scenario) -> begin #----------- fit g and θP to y_o rng = StableRNG(111) g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) - train_loader = get_hybridproblem_train_dataloader(prob; n_batch=10, scenario) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + train_loader = get_hybridproblem_train_dataloader(prob; scenario) (xM, xP, y_o, y_unc, i_sites) = first(train_loader) - f = get_hybridproblem_PBmodel(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) par_templates = get_hybridproblem_par_templates(prob; scenario) #f(par_templates.θP, hcat(par_templates.θM, par_templates.θM), xP[1:2]) (; transM, transP) = get_hybridproblem_transforms(prob; scenario) @@ -120,7 +125,7 @@ test_without_flux = (scenario) -> begin y_global_o = Float64[] loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; pbm_covars) l1 = loss_gf(p0, first(train_loader)...) 
- tld = train_loader.data + tld = first(train_loader) gr = Zygote.gradient(p -> loss_gf(p, tld...)[1], CA.getdata(p0)) @test gr[1] isa Vector @@ -147,15 +152,15 @@ using GPUArraysCore import Flux gdev = gpu_device() -#methods(HVI.vec2uutri) +#methods(CP.vec2uutri) test_with_flux = (scenario) -> begin prob = probc = construct_problem(;scenario); @testset "HybridPointSolver" begin rng = StableRNG(111) - solver = HybridPointSolver(; alg=Adam(0.02), n_batch=11) - (; ϕ, resopt) = solve(prob, solver; scenario, rng, + solver = HybridPointSolver(; alg=Adam(0.02)) + (; ϕ, resopt, probo) = solve(prob, solver; scenario, rng, #callback = callback_loss(100), maxiters = 1200 #maxiters = 1200 #maxiters = 20 @@ -164,13 +169,17 @@ test_with_flux = (scenario) -> begin #gpu_handler = NullGPUDataHandler ) (; θP) = get_hybridproblem_par_templates(prob; scenario) - @test ϕ.ϕP.r0 < 1.5 * θP.r0 + θPo = (() -> begin + (; θP) = get_hybridproblem_par_templates(probo; scenario); + θP + end)() + @test θPo.r0 < 1.5 * θP.r0 @test ϕ.ϕP.K2 < 1.5 * log(θP.K2) end; @testset "HybridPosteriorSolver" begin rng = StableRNG(111) - solver = HybridPosteriorSolver(; alg=Adam(0.02), n_batch=11, n_MC=3) + solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) (; ϕ, θP, resopt) = solve(prob, solver; scenario, rng, #callback = callback_loss(100), maxiters = 1200, #maxiters = 20 # too small so that it yields error @@ -189,9 +198,12 @@ test_with_flux = (scenario) -> begin @testset "HybridPosteriorSolver gpu" begin scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0) rng = StableRNG(111) - prob = probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf) - solver = HybridPosteriorSolver(; alg=Adam(0.02), n_batch=11, n_MC=3) - n_batches_in_epoch = get_hybridproblem_n_site(prob; scenario) ÷ solver.n_batch + # here using DoubleMMCase() directly rather than construct_problem + #(;transP, transM) = get_hybridproblem_transforms(DoubleMM.DoubleMMCase(); scenario = scenf) + prob = probg = 
HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf); + solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario = scenf) + n_batches_in_epoch = n_site ÷ n_batch (; ϕ, θP, resopt) = solve(prob, solver; scenario = scenf, rng, maxiters = 37, # smallest value by trial and error #maxiters = 20 # too small so that it yields error @@ -200,14 +212,20 @@ test_with_flux = (scenario) -> begin @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector #@test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations in test -> may fail # - solver = HybridPosteriorSolver(; alg=Adam(0.02), n_batch=11, n_MC=3) + solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) + (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, + maxiters = 37, + ); + @test cdev(ϕ.unc.ρsM)[1] > 0 + @test probo.ϕunc == cdev(ϕ.unc) test_correlation = () -> begin - n_epoch = 100 # requires + n_epoch = 20 # requires (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, maxiters = n_batches_in_epoch * n_epoch, callback = callback_loss(n_batches_in_epoch*5) ); @test cdev(ϕ.unc.ρsM)[1] > 0 + @test probo.ϕunc == cdev(ϕ.unc) # predict using problem and its associated dataloader (; θ, y, entropy_ζ) = predict_gf(rng, probo; scenario = scenf, n_sample_pred = 200); mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1]) @@ -221,24 +239,23 @@ test_with_flux = (scenario) -> begin end end; - # does not work with general Bijector: - # @testset "HybridPosteriorSolver also f on gpu" begin - # scenario = (:use_Flux, :use_gpu, :omit_r0, :f_on_gpu) - # rng = StableRNG(111) - # probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario) - # prob = HVI.update(probg) - # #prob = HVI.update(probg, transM = identity, transP = identity) - # solver = HybridPosteriorSolver(; alg=Adam(0.02), n_batch=11, n_MC=3) - # n_batches_in_epoch = get_hybridproblem_n_site(prob; scenario) ÷ solver.n_batch - # (; ϕ, θP, resopt) = 
solve(prob, solver; scenario, rng, - # maxiters = 37, # smallest value by trial and error - # #maxiters = 20 # too small so that it yields error - # #θmean_quant = 0.01, # TODO make possible on gpu - # cdev = identity # do not move ζ to cpu # TODO infer in solve from scenario - # ) - # @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector - # @test cdev(ϕ.unc.ρsM)[1] > 0 - # end; + @testset "HybridPosteriorSolver also f on gpu" begin + scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu) + rng = StableRNG(111) + probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf); + #prob = CP.update(probg, transM = identity, transP = identity); + solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario = scenf) + n_batches_in_epoch = n_site ÷ n_batch + (; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf, rng, + maxiters = 37, # smallest value by trial and error + #maxiters = 20 # too small so that it yields error + #θmean_quant = 0.01, # TODO make possible on gpu + cdev = identity # do not move ζ to cpu # TODO infer in solve from scenario + ); + @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector + # @test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations + end; end # if gdev isa MLDataDevices.AbstractGPUDevice end # test_with flux diff --git a/test/test_bijectors_utils.jl b/test/test_bijectors_utils.jl new file mode 100644 index 0000000..351ee86 --- /dev/null +++ b/test/test_bijectors_utils.jl @@ -0,0 +1,73 @@ +using Test +using HybridVariationalInference +using HybridVariationalInference: HybridVariationalInference as CP + +using Bijectors + +using MLDataDevices +import CUDA, cuDNN +using Zygote + + +x = [0.1, 0.2, 0.3, 0.4] +gdev = gpu_device() +cdev = cpu_device() + +function trans(x, b) + y, logjac = Bijectors.with_logabsdet_jacobian(b, x) + sum(y .+ logjac) +end + +b2 = elementwise(exp) +b2s = Stacked((b2,b2),(1:3,4:4)) +b3 = HybridVariationalInference.Exp() +b3s = 
Stacked((b3,b3), (1:3,4:4)) +#b3s = Stacked((b3,),(1:4,)) + + +y = trans(x, b2) +dy = Zygote.gradient(x -> trans(x,b2), x) + + +@testset "elementwise exp" begin + ys = trans(x,b2s) + @test ys == y + Zygote.gradient(x -> trans(x,b2s), x) +end; + +@testset "Exp" begin + y1 = b3(x) + y2 = b3s(x) + @test all(inverse(b3)(y2) .≈ x) + @test all(inverse(b3s)(y2) .≈ x) + ye = trans(x, b3) + dye = Zygote.gradient(x -> trans(x,b3), x) + @test ye == y + @test dye == dy + ys = trans(x,b3s) + dys = Zygote.gradient(x -> trans(x,b2s), x) + @test dys == dy +end; + + +if gdev isa MLDataDevices.AbstractGPUDevice + xd = gdev(x) + @testset "elementwise exp gpu" begin + ys = trans(xd,b2) + @test ys ≈ y + @test_broken Zygote.gradient(x -> trans(x,b2), xd) + @test_broken Zygote.gradient(x -> trans(x,b2s), xd) + end; + + @testset "Exp" begin + ye = trans(xd, b3) + dye = Zygote.gradient(x -> trans(x,b3), xd) + @test ye ≈ y + @test all(cdev(dye) .≈ dy) + ys = trans(xd,b3s) + dys = Zygote.gradient(x -> trans(x,b3s), xd) + @test ys ≈ y + @test all(cdev(dys) .≈ dy) + end; +end + diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 6b42803..e3e84ad 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -14,7 +14,14 @@ import Zygote using OptimizationOptimisers using MLDataDevices -const prob = DoubleMM.DoubleMMCase() +using CUDA: CUDA +using Flux +using GPUArraysCore + +gdev = gpu_device() +cdev = cpu_device() + +prob = DoubleMM.DoubleMMCase() scenario = (:default,) #using Flux #scenario = (:use_Flux,) @@ -28,16 +35,20 @@ par_templates = get_hybridproblem_par_templates(prob; scenario) @test quantile(priors[:K2], 0.95) ≈ θall.K2 * 3 # fitted in f_doubleMM end -rng = StableRNG(111) -(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc +rng = StableRNG(111) # make sure to be the same as when constructing train_dataloader +(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc ) = gen_hybridproblem_synthetic(rng, prob; 
scenario); +n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) i_sites = 1:n_site +fneglogden = get_hybridproblem_neg_logden_obs(prob; scenario) @testset "gen_hybridproblem_synthetic" begin @test isapprox( vec(mean(CA.getdata(θMs_true); dims = 2)), CA.getdata(par_templates.θM), rtol = 0.02) @test isapprox(vec(std(CA.getdata(θMs_true); dims = 2)), CA.getdata(par_templates.θM) .* 0.1, rtol = 0.02) + @test size(xP) == (n_site,) + @test size(y_o) == (8, n_site) # test same results for same rng rng2 = StableRNG(111) @@ -45,6 +56,76 @@ i_sites = 1:n_site @test gen2.y_o == y_o end +@testset "f_doubleMM_Matrix" begin + is = repeat(axes(θP_true, 1)', n_site) + θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true) + xPM = map(xP1s -> repeat(xP1s', n_site), xP[1]) + #θ = hcat(θP_true[is], θMs_true') + intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1]))) + #θpos = get_positions(intθ1) + intθ = get_concrete(ComponentArrayInterpreter((n_site,), intθ1)) + fy = (θvec, xPM) -> begin + θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) + y = HVI.DoubleMM.f_doubleMM(θ, xPM, intθ) + #y = HVI.DoubleMM.f_doubleMM(θ, xPM, θpos) + end + y = fy(θvec, xPM) + y_exp = applyf(HVI.DoubleMM.f_doubleMM, θMs_true, θP_true, + Vector{eltype(θP_true)}(undef, 0), xP, intθ1) + @test y == y_exp' + ygrad = Zygote.gradient(θv -> sum(fy(θv, xPM)), θvec)[1] + if gdev isa MLDataDevices.AbstractGPUDevice + # θg = gdev(θ) + # xPMg = gdev(xPM) + # yg = HVI.DoubleMM.f_doubleMM(θg, xPMg, intθ); + θvecg = gdev(θvec) + xPMg = gdev(xPM) + yg = fy(θvecg, xPMg) + @test cdev(yg) == y_exp' + ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1] # errors without ";" + @test ygradg isa CA.ComponentArray + @test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray + ygradgc = HVI.apply_preserve_axes(cdev, ygradg) # can print the cpu version + # ygradgc.P .- ygrad.P + # ygradgc.Ms + end +end + +@testset "neg_logden_obs Matrix" begin + is = repeat(axes(θP_true, 1)', 
n_site) + θvec = CA.ComponentVector(P = θP_true, Ms = θMs_true) + xPM = map(xP1s -> repeat(xP1s', n_site), xP[1]) + #θ = hcat(θP_true[is], θMs_true') + intθ1 = get_concrete(ComponentArrayInterpreter(vcat(θP_true, θMs_true[:, 1]))) + #θpos = get_positions(intθ1) + intθ = get_concrete(ComponentArrayInterpreter((n_site,), intθ1)) + fcost = (θvec, xPM, y_o, y_unc) -> begin + θ = hcat(CA.getdata(θvec.P[is]), CA.getdata(θvec.Ms')) + y = HVI.DoubleMM.f_doubleMM(θ, xPM, intθ) + #y = HVI.DoubleMM.f_doubleMM(θ, xPM, θpos) + fneglogden(y_o, y', y_unc) + end + cost = fcost(θvec, xPM, y_o, y_unc) + ygrad = Zygote.gradient(θv -> fcost(θv, xPM, y_o, y_unc), θvec)[1] + if gdev isa MLDataDevices.AbstractGPUDevice + # θg = gdev(θ) + # xPMg = gdev(xPM) + # yg = HVI.DoubleMM.f_doubleMM(θg, xPMg, intθ); + θvecg = gdev(θvec) + xPMg = gdev(xPM) + y_og = gdev(y_o) + y_uncg = gdev(y_unc) + costg = fcost(θvecg, xPMg, y_og, y_uncg) + @test costg ≈ cost + ygradg = Zygote.gradient(θv -> fcost(θv, xPMg, y_og, y_uncg), θvecg)[1] # errors without ";" + @test ygradg isa CA.ComponentArray + @test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray + ygradgc = HVI.apply_preserve_axes(cdev, ygradg) # can print the cpu version + # ygradgc.P .- ygrad.P + # ygradgc.Ms + end +end + @testset "loss_g" begin g, ϕg0 = get_hybridproblem_MLapplicator(rng, prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) @@ -89,23 +170,27 @@ end #----------- fit g and θP to y_o (without uncertainty, without transforming θP) g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario) (; transP, transM) = get_hybridproblem_transforms(prob; scenario) - f = get_hybridproblem_PBmodel(prob; scenario) + n_site, n_site_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) + f2 = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) intϕ = ComponentArrayInterpreter(CA.ComponentVector( ϕg = 1:length(ϕg0), ϕP = 
par_templates.θP)) - p = p0 = vcat(ϕg0, HVI.apply_preserve_axes(inverse(transP), par_templates.θP) .- + p = p0 = vcat(ϕg0, + HVI.apply_preserve_axes(inverse(transP), par_templates.θP) .- convert(eltype(ϕg0), 0.1)) # slightly disturb θP_true #p = p0 = vcat(ϕg_opt1, par_templates.θP); # almost true # Pass the site-data for the batches as separate vectors wrapped in a tuple - n_batch = 10 - train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites), batchsize = n_batch) - # get_hybridproblem_train_dataloader recreates synthetic data different θ_true - train_loader2 = get_hybridproblem_train_dataloader(prob; scenario, n_batch = n_site) - pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) + # train_loader = MLUtils.DataLoader( + # (xM, xP, y_o, y_unc, i_sites), batchsize = n_site_batch) + train_loader = get_hybridproblem_train_dataloader(prob; scenario) + @assert train_loader.data == (xM, xP, y_o, y_unc, i_sites) + pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) #loss_gf = get_loss_gf(g, transM, f, y_global_o, intϕ; gdev = identity) loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ; pbm_covars) + loss_gf2 = get_loss_gf(g, transM, transP, f2, y_global_o, intϕ; pbm_covars) l1 = loss_gf(p0, first(train_loader)...)[1] (xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch) = first(train_loader) Zygote.gradient( @@ -120,7 +205,7 @@ end #optprob, Adam(0.02), callback = callback_loss(100), maxiters = 5000); optprob, Adam(0.02), maxiters = 1000) - l1, y_pred_global, y_pred, θMs_pred, θP_pred = loss_gf(res.u, train_loader.data...) + l1, y_pred_global, y_pred, θMs_pred, θP_pred = loss_gf2(res.u, train_loader.data...) 
#l1, y_pred_global, y_pred, θMs_pred = loss_gf(p0, xM, xP, y_o, y_unc); θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true)) #TODO @test isapprox(par_templates.θP, intϕ(res.u).ϕP, rtol = 0.15) @@ -141,22 +226,17 @@ end end end -using CUDA: CUDA -using Flux -using GPUArraysCore - -gdev = gpu_device() -cdev = cpu_device() if gdev isa MLDataDevices.AbstractGPUDevice + scenario = (:use_Flux,) + g, ϕg0 = get_hybridproblem_MLapplicator(rng, prob; scenario) + ϕg0_gpu = gdev(ϕg0) + xM_gpu = gdev(xM) + g_gpu = gdev(g) + @testset "transfer NormalScalingModelApplicator to gpu" begin - scenario = (:use_Flux,) - g, ϕg0 = get_hybridproblem_MLapplicator(rng, prob; scenario) - ϕg = gdev(ϕg0) - xM_gpu = gdev(xM) - g_gpu = gdev(g) @test g_gpu.μ isa GPUArraysCore.AbstractGPUArray - y_gpu = g_gpu(xM_gpu, ϕg) + y_gpu = g_gpu(xM_gpu, ϕg0_gpu) y = g(xM, ϕg0) @test cdev(y_gpu) ≈ y - end; -end + end +end \ No newline at end of file diff --git a/test/test_elbo.jl b/test/test_elbo.jl index af85a73..3c1cf12 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -36,15 +36,17 @@ test_scenario = (scenario) -> begin #θsite_true = get_hybridproblem_par_templates(prob; scenario) + n_covar = 5 + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc + ) = gen_hybridproblem_synthetic(rng, prob; scenario); + g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario); - f = get_hybridproblem_PBmodel(prob; scenario) + f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) + f_pred = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) - n_covar = 5 - n_batch = 10 n_θM, n_θP = values(map(length, get_hybridproblem_par_templates(prob; scenario))) - (; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc - ) = gen_hybridproblem_synthetic(rng, prob; scenario); py = neg_logden_indep_normal @@ -149,7 +151,7 @@ test_scenario = (scenario) -> begin trans_PMs_gen 
= get_transPMs(n_site) @test length(intm_PMs_gen) == 402 @test trans_PMs_gen.length_in == 402 - (; θ, y) = predict_gf(rng, g, f, ϕ_ini, xM, xP, map(get_concrete, interpreters); + (; θ, y) = predict_gf(rng, g, f_pred, ϕ_ini, xM, xP, map(get_concrete, interpreters); get_transPMs, get_ca_int_PMs, n_sample_pred, cor_ends, pbm_covar_indices) @test θ isa CA.ComponentMatrix @test θ[:, 1].P.r0 > 0 @@ -162,7 +164,7 @@ test_scenario = (scenario) -> begin n_sample_pred = 200 ϕ = ggdev(CA.getdata(ϕ_ini)) xMg = ggdev(xM) - (; θ, y) = predict_gf(rng, g_gpu, f, ϕ, xMg, xP, map(get_concrete, interpreters); + (; θ, y) = predict_gf(rng, g_gpu, f_pred, ϕ, xMg, xP, map(get_concrete, interpreters); get_transPMs, get_ca_int_PMs, n_sample_pred, cor_ends, pbm_covar_indices) @test θ isa CA.ComponentMatrix # only ML parameters are on gpu @test θ[:, 1].P.r0 > 0 @@ -182,7 +184,7 @@ test_scenario = (scenario) -> begin # n_sample_pred = 200 # ϕ = ggdev(CA.getdata(ϕ_ini)) # xMg = ggdev(xM) - # (; θ, y) = predict_gf(rng, g_gpu, f, ϕ, xMg, ggdev(xP), map(get_concrete, interpreters); + # (; θ, y) = predict_gf(rng, g_gpu, f_pred, ϕ, xMg, ggdev(xP), map(get_concrete, interpreters); # get_ca_int_PMs, n_sample_pred, cor_ends, pbm_covar_indices, # get_transPMs = get_transPMs_ident, # cdev = identity); # keep on gpu diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl index 16719d2..4749c11 100644 --- a/test/test_sample_zeta.jl +++ b/test/test_sample_zeta.jl @@ -23,8 +23,10 @@ scenario = (:default,) n_θM, n_θP = length.(values(get_hybridproblem_par_templates(prob; scenario))) -(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o +(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o ) = gen_hybridproblem_synthetic(rng, prob; scenario) +n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + FT = get_hybridproblem_float_type(prob; scenario) @@ -65,7 +67,7 @@ end if ggdev isa MLDataDevices.AbstractGPUDevice @testset "sample_ζ_norm0 gpu" 
begin # sample only n_batch of 50 - n_batch = 40 + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) ϕb = CA.ComponentVector(P = ϕ_cpu.P, Ms = ϕ_cpu.Ms[:,1:n_batch], unc = ϕ_cpu.unc) intb = ComponentArrayInterpreter(ϕb) ϕ = ggdev(CA.getdata(ϕb))