Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ dataset on which the ensembler should be trained on.
This function currently assumes that `sol.t` matches the time points of all measurements
in `data_ensem`!
"""
function ensemble_weights(sol::EnsembleSolution, data_ensem)
function ensemble_weights(sol::EnsembleSolution, data_ensem; rank = size(data_ensem,2))
obs = first.(data_ensem)
predictions = reduce(vcat, reduce(hcat,[sol[i][s] for i in 1:length(sol)]) for s in obs)
data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)])
weights = predictions \ data
F = svd(data)
# Truncate SVD
U, S, V = F.U[:, 1:rank], F.S[1:rank], F.V[:, 1:rank]
weights = (((data*V)*Diagonal(1 ./ S)) * U')
end

function bayesian_ensemble(probs, ps, datas;
Expand All @@ -46,4 +49,4 @@ function bayesian_ensemble(probs, ps, datas;
@info "$(length(all_probs)) total models"

enprob = EnsembleProblem(all_probs)
end
end