Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions dev/doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,34 @@ gdev = :use_gpu ∈ scenario ? gpu_device() : identity
cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity

#------ setup synthetic data and training data loader
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario);
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
#n_site = get_hybridproblem_n_site(DoubleMM.DoubleMMCase(); scenario)
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
ζP_true, ζMs_true = log.(θP_true), log.(θMs_true)
i_sites = 1:n_site
xM_cpu = xM;
xM = xM_cpu |> gdev;
get_train_loader = (; n_batch, kwargs...) -> MLUtils.DataLoader(
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
train_dataloader = MLUtils.DataLoader(
(xM, xP, y_o, y_unc, 1:n_site);
batchsize = n_batch, partial = false)
σ_o = exp.(y_unc[:, 1] / 2)

# assign the train_loader, otherwise it eatch time creates another version of synthetic data
prob0 = HVI.update(HybridProblem(DoubleMM.DoubleMMCase(); scenario); get_train_loader)
prob0 = HVI.update(prob0_; train_dataloader);
#tmp = HVI.get_hybridproblem_ϕunc(prob0; scenario)

#------- pointwise hybrid model fit
solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01), n_batch = 30)
solver_point = HybridPointSolver(; alg = OptimizationOptimisers.Adam(0.01))
#solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 30)
#solver_point = HybridPointSolver(; alg = Adam(0.01), n_batch = 10)
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
n_batches_in_epoch = n_site ÷ solver_point.n_batch
n_batches_in_epoch = n_site ÷ n_batch
n_epoch = 80
(; ϕ, resopt, probo) = solve(prob0, solver_point; scenario,
rng, callback = callback_loss(n_batches_in_epoch * 10),
maxiters = n_batches_in_epoch * n_epoch);
# update the problem with optimized parameters
prob0o = probo;
y_pred_global, y_pred, θMs = gf(prob0o, xM, xP; scenario);
y_pred_global, y_pred, θMs = gf(prob0o, scenario);
plt = scatterplot(θMs_true[1, :], θMs[1, :]);
lineplot!(plt, 0, 1)
scatterplot(θMs_true[2, :], θMs[2, :])
Expand Down Expand Up @@ -149,10 +148,10 @@ probh = prob0o # start from point optimized to infer uncertainty
#probh = prob1o # start from point optimized to infer uncertainty
#probh = prob0 # start from no information
solver_post = HybridPosteriorSolver(;
alg = OptimizationOptimisers.Adam(0.01), n_batch = min(50, n_site), n_MC = 3)
alg = OptimizationOptimisers.Adam(0.01), n_MC = 3)
#solver_point = HybridPointSolver(; alg = Adam(), n_batch = 200)
n_batches_in_epoch = n_site ÷ solver_post.n_batch
n_epoch = 80
n_batches_in_epoch = n_site ÷ n_batch
n_epoch = 40
(; ϕ, θP, resopt, interpreters, probo) = solve(probh, solver_post; scenario,
rng, callback = callback_loss(n_batches_in_epoch * 5),
maxiters = n_batches_in_epoch * n_epoch,
Expand Down Expand Up @@ -213,6 +212,7 @@ end
n_sample_pred = 400
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
(θ2_indep, y2_indep) = (θ, y)
#(θ2_indep, y2_indep) = (θ2, y2) # workaround to use covarK2 when loading failed
end

() -> begin # otpimize using LUX
Expand Down Expand Up @@ -246,7 +246,7 @@ exp.(ϕunc_VI.coef_logσ2_logMs[1, :])

# test predicting correct obs-uncertainty of predictive posterior
n_sample_pred = 400
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o, xM, xP; scenario, n_sample_pred);
(; θ, y, entropy_ζ) = predict_gf(rng, prob2o; scenario, n_sample_pred);
(θ2, y2) = (θ, y)
size(y) # n_obs x n_site, n_sample_pred
size(θ) # n_θP + n_site * n_θM x n_sample
Expand Down Expand Up @@ -506,12 +506,13 @@ chain = sample(model, NUTS(), MCMCThreads(), ceil(Integer,n_sample_NUTS/n_thread
using JLD2
fname = "intermediate/doubleMM_chain_zeta_$(last(scenario)).jld2"
jldsave(fname, false, IOStream; chain)
chain = load(fname, "chain"; iotype = IOStream)
chain = load(fname, "chain"; iotype = IOStream);
end

#ζi = first(eachrow(Array(chain)))
f_allsites = get_hybridproblem_PBmodel(prob0; scenario, use_all_sites = true)
ζs = mapreduce(ζi -> transposeMs(ζi, intm_PMs_gen, true), hcat, eachrow(Array(chain)));
(; θ, y) = HVI.predict_ζf(ζs, f, xP, trans_PMs_gen, intm_PMs_gen);
(; θ, y) = HVI.predict_ζf(ζs, f_allsites, xP, trans_PMs_gen, intm_PMs_gen);
(ζs_hmc, θ_hmc, y_hmc) = (ζs, θ, y);


Expand Down
Binary file removed dev/negLogDensity.pdf
Binary file not shown.
Binary file removed dev/r1_density.pdf
Binary file not shown.
Binary file removed dev/ys_density.pdf
Binary file not shown.
47 changes: 34 additions & 13 deletions src/AbstractHybridProblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ For a specific prob, provide functions that specify details
- `get_hybridproblem_train_dataloader` (may use `construct_dataloader_from_synthetic`)
- `get_hybridproblem_priors`
- `get_hybridproblem_n_covar`
- `get_hybridproblem_n_site`
- `get_hybridproblem_n_site_and_batch`
optionally
- `gen_hybridproblem_synthetic`
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
Expand Down Expand Up @@ -125,11 +125,11 @@ function get_hybridproblem_pbmpar_covars(::AbstractHybridProblem; scenario)
end

"""
get_hybridproblem_n_site(::AbstractHybridProblem; scenario)
get_hybridproblem_n_site_and_batch(::AbstractHybridProblem; scenario)

Provide the number of sites.
"""
function get_hybridproblem_n_site end
function get_hybridproblem_n_site_and_batch end


"""
Expand Down Expand Up @@ -172,30 +172,51 @@ function get_hybridproblem_train_dataloader end
scenario = (), n_batch)

Construct a dataloader based on `gen_hybridproblem_synthetic`.
gdev is applied to xM.
If :f_on_gpu is in scenario tuple, gdev is also applied to `xP`, `y_o`, and `y_unc`,
to put the entire data to gpu.
Alternatively, gdev could be applied to the dataloader, then for each
iteration the subset of data is separately transferred to gpu.
"""
function construct_dataloader_from_synthetic(rng::AbstractRNG, prob::AbstractHybridProblem;
scenario = (), n_batch,
gdev = :use_gpu ∈ scenario ? gpu_device() : identity,
#gdev = :use_gpu ∈ scenario ? gpu_device() : identity,
)
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario)
n_site = size(xM,2)
@assert length(xP) == n_site
@assert size(y_o,2) == n_site
@assert size(y_unc,2) == n_site
i_sites = 1:n_site
xM_dev = gdev(xM)
xP_dev, y_o_dev, y_unc_dev = :f_on_gpu ∈ scenario ?
(gdev(xP), gdev(y_o), gdev(y_unc)) : (xP, y_o, y_unc)
train_loader = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites);
train_loader = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites);
batchsize = n_batch, partial = false)
return (train_loader)
end


"""
gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader,
scenario = (),
gdev = gpu_device(),
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
batchsize = dataloader.batchsize,
partial = dataloader.partial
)

Put relevant parts of the DataLoader to gpu, depending on scenario.
"""
function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
scenario = (),
gdev = gpu_device(),
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
batchsize = dataloader.batchsize,
partial = dataloader.partial
)
xM, xP, y_o, y_unc, i_sites = dataloader.data
xM_dev = gdev_M(xM)
xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc))
train_loader_dev = MLUtils.DataLoader((xM_dev, xP_dev, y_o_dev, y_unc_dev, i_sites);
batchsize, partial)
return(train_loader_dev)
end

# function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ())
# rng::AbstractRNG = Random.default_rng()
# get_hybridproblem_train_dataloader(rng, prob; scenario)
Expand Down
47 changes: 42 additions & 5 deletions src/ComponentArrayInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

Interface for Type that implements
- `as_ca(::AbstractArray, interpreter) -> ComponentArray`
- `ComponentArrays.getaxes(interpreter)`
- `Base.length(interpreter) -> Int`

When called on a vector, forwards to `as_ca`.

There is a default implementation for Base.length based on ComponentArrays.getaxes.
"""
abstract type AbstractComponentArrayInterpreter end

Expand All @@ -18,6 +21,11 @@ Returns a ComponentArray with underlying data `v`.
"""
function as_ca end

function Base.length(cai::AbstractComponentArrayInterpreter)
prod(_axis_length.(CA.getaxes(cai)))
end


(interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray) = as_ca(v, interpreter)

"""
Expand All @@ -36,9 +44,13 @@ function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {A
CA.ComponentArray(vr, AX)
end

function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
#sum(length, typeof(AX).parameters[1])
prod(_axis_length.(AX))
# function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
# #sum(length, typeof(AX).parameters[1])
# prod(_axis_length.(AX))
# end

function CA.getaxes(int::StaticComponentArrayInterpreter{AX}) where {AX}
AX
end

get_concrete(cai::StaticComponentArrayInterpreter) = cai
Expand All @@ -63,10 +75,11 @@ function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter)
CA.ComponentArray(vr, cai.axes)
end

function Base.length(cai::ComponentArrayInterpreter)
prod(_axis_length.(cai.axes))
function CA.getaxes(cai::ComponentArrayInterpreter)
cai.axes
end


get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{cai.axes}()


Expand Down Expand Up @@ -120,6 +133,10 @@ function ComponentArrayInterpreter(
ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where N
ComponentArrayInterpreter(CA.getaxes(ca), n_dims)
end
function ComponentArrayInterpreter(
cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where N
ComponentArrayInterpreter(CA.getaxes(cai), n_dims)
end
function ComponentArrayInterpreter(
axes::NTuple{M, <:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...)
Expand All @@ -131,12 +148,17 @@ function ComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where N
ComponentArrayInterpreter(n_dims, CA.getaxes(ca))
end
function ComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where N
ComponentArrayInterpreter(n_dims, CA.getaxes(cai))
end
function ComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, axes::NTuple{M, <:CA.AbstractAxis}) where {N,M}
axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...)
ComponentArrayInterpreter(axes_ext)
end


# ambuiguity with two empty Tuples (edge prob that does not make sense)
# Empty ComponentVector with no other array dimensions -> empty componentVector
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
Expand All @@ -156,6 +178,8 @@ _axis_length(::CA.FlatAxis) = 0
_axis_length(::CA.UnitRange) = 0

"""
flatten1(cv::CA.ComponentVector)

Removes the highest level of keys.
Keeps the reference to the underlying data, but changes the axis.
If first-level vector has no sub-names, an error (Aguement Error tuple must be non-empty)
Expand All @@ -174,3 +198,16 @@ function flatten1(cv::CA.ComponentVector)
CA.ComponentVector(cv, first(CA.getaxes(cv_new)))
end
end


"""
get_positions(cai::AbstractComponentArrayInterpreter)

Create a NamedTuple of integer indices for each component.
Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector.
"""
function get_positions(cai::AbstractComponentArrayInterpreter)
@assert length(CA.getaxes(cai)) == 1
cv = cai(1:length(cai))
(; (k => cv[k] for k in keys(cv))... )
end
Loading