4 changes: 4 additions & 0 deletions Project.toml
@@ -17,6 +17,8 @@ GaussianProcesses = "891a1506-143c-57d2-908e-e1f8e92e6de9"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
@@ -42,6 +44,8 @@ ForwardDiff = "0.10.38, 1"
GaussianProcesses = "0.12"
KernelFunctions = "0.10.64"
MCMCChains = "4.14, 5, 6, 7"
Manifolds = "0.10.23"
Manopt = "0.5.20"
Printf = "1"
ProgressBars = "1"
PyCall = "1.93"
3 changes: 1 addition & 2 deletions src/Utilities.jl
@@ -1,7 +1,5 @@
module Utilities



using DocStringExtensions
using LinearAlgebra
using Statistics
@@ -477,5 +475,6 @@ end
include("Utilities/canonical_correlation.jl")
include("Utilities/decorrelator.jl")
include("Utilities/elementwise_scaler.jl")
include("Utilities/likelihood_informed.jl")

end # module
8 changes: 4 additions & 4 deletions src/Utilities/canonical_correlation.jl
@@ -181,9 +181,9 @@ initialize_processor!(
"""
$(TYPEDSIGNATURES)

Apply the `CanonicalCorrelation` encoder, on a columns-are-data matrix
Apply the `CanonicalCorrelation` encoder, on a columns-are-data matrix or a data vector
"""
function encode_data(cc::CanonicalCorrelation, data::MM) where {MM <: AbstractMatrix}
function encode_data(cc::CanonicalCorrelation, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
data_mean = get_data_mean(cc)[1]
encoder_mat = get_encoder_mat(cc)[1]
return encoder_mat * (data .- data_mean)
@@ -192,9 +192,9 @@ end
"""
$(TYPEDSIGNATURES)

Apply the `CanonicalCorrelation` decoder, on a columns-are-data matrix
Apply the `CanonicalCorrelation` decoder, on a columns-are-data matrix or a data vector
"""
function decode_data(cc::CanonicalCorrelation, data::MM) where {MM <: AbstractMatrix}
function decode_data(cc::CanonicalCorrelation, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
data_mean = get_data_mean(cc)[1]
decoder_mat = get_decoder_mat(cc)[1]
return decoder_mat * data .+ data_mean
8 changes: 4 additions & 4 deletions src/Utilities/decorrelator.jl
@@ -187,9 +187,9 @@ end
"""
$(TYPEDSIGNATURES)

Apply the `Decorrelator` encoder, on a columns-are-data matrix
Apply the `Decorrelator` encoder, on a columns-are-data matrix or a data vector
"""
function encode_data(dd::Decorrelator, data::MM) where {MM <: AbstractMatrix}
function encode_data(dd::Decorrelator, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
data_mean = get_data_mean(dd)[1]
encoder_mat = get_encoder_mat(dd)[1]
return encoder_mat * (data .- data_mean)
@@ -198,9 +198,9 @@ end
"""
$(TYPEDSIGNATURES)

Apply the `Decorrelator` decoder, on a columns-are-data matrix
Apply the `Decorrelator` decoder, on a columns-are-data matrix or a data vector
"""
function decode_data(dd::Decorrelator, data::MM) where {MM <: AbstractMatrix}
function decode_data(dd::Decorrelator, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
data_mean = get_data_mean(dd)[1]
decoder_mat = get_decoder_mat(dd)[1]
return decoder_mat * data .+ data_mean
8 changes: 4 additions & 4 deletions src/Utilities/elementwise_scaler.jl
@@ -126,9 +126,9 @@ end
"""
$(TYPEDSIGNATURES)

Apply the `ElementwiseScaler` encoder, on a columns-are-data matrix
Apply the `ElementwiseScaler` encoder, on a columns-are-data matrix or a data vector
"""
function encode_data(es::ElementwiseScaler, data::MM) where {MM <: AbstractMatrix}
function encode_data(es::ElementwiseScaler, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
out = deepcopy(data)
for i in 1:size(out, 1)
out[i, :] .-= get_shift(es)[i]
@@ -140,9 +140,9 @@ end
"""
$(TYPEDSIGNATURES)

Apply the `ElementwiseScaler` decoder, on a columns-are-data matrix
Apply the `ElementwiseScaler` decoder, on a columns-are-data matrix or a data vector
"""
function decode_data(es::ElementwiseScaler, data::MM) where {MM <: AbstractMatrix}
function decode_data(es::ElementwiseScaler, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
out = deepcopy(data)
for i in 1:size(out, 1)
out[i, :] *= get_scale(es)[i]
199 changes: 199 additions & 0 deletions src/Utilities/likelihood_informed.jl
@@ -0,0 +1,199 @@
# included in Utilities.jl

using Manifolds, Manopt

export LikelihoodInformed, likelihood_informed

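# Data processor that projects onto a likelihood-informed subspace: `encoder_mat` and
# `decoder_mat` are built in `initialize_processor!` from sample-based gradient estimates
# of the input-output map, either by truncating an eigendecomposition (inputs, or α ≈ 0)
# or by optimizing an orthonormal basis over a Grassmann manifold (outputs with α ≠ 0).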
mutable struct LikelihoodInformed{FT <: Real} <: PairedDataContainerProcessor
encoder_mat::Union{Nothing, AbstractMatrix}
decoder_mat::Union{Nothing, AbstractMatrix}
apply_to::Union{Nothing, AbstractString}
dim_criterion::Tuple{Symbol, <:Number}
α::FT
grad_type::Symbol
use_prior_samples::Bool
end

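# Convenience constructor. `retain_KL` is the fraction of the diagnostic eigenvalue mass to
# retain when truncating; `alpha` weights a data-misfit term in the diagnostic matrix (see
# `initialize_processor!`); `grad_type` selects the gradient estimate (:linreg for a single
# global linear regression, :localsl for locally weighted fits); `use_prior_samples` reads
# prior samples from the structure vectors (:prior_samples_in / :prior_samples_out) instead
# of the paired data. Illustrative call (values are hypothetical):
#   li = likelihood_informed(0.9; grad_type = :linreg)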
function likelihood_informed(retain_KL; alpha = 0.0, grad_type = :localsl, use_prior_samples = true)
if grad_type ∉ [:linreg, :localsl]
@error "Unknown grad_type=$grad_type"
end

LikelihoodInformed(nothing, nothing, nothing, (:retain_KL, retain_KL), alpha, grad_type, use_prior_samples)
end

get_encoder_mat(li::LikelihoodInformed) = li.encoder_mat
get_decoder_mat(li::LikelihoodInformed) = li.decoder_mat

function initialize_processor!(
li::LikelihoodInformed,
in_data::MM,
out_data::MM,
::Dict{Symbol, <:StructureMatrix},
output_structure_matrices::Dict{Symbol, <:StructureMatrix},
input_structure_vectors::Dict{Symbol, <:StructureVector},
output_structure_vectors::Dict{Symbol, <:StructureVector},
apply_to::AbstractString,
) where {MM <: AbstractMatrix}
output_dim = size(out_data, 1)

if isnothing(get_encoder_mat(li))
α = li.α
y = if α ≈ 0.0
# For α=0, it doesn't matter what this value is, so we avoid requiring its presence
zeros(size(out_data, 1))
else
get_structure_vec(output_structure_vectors, :observation)
end
samples_in, samples_out = if li.use_prior_samples
@assert α ≈ 0.0
(
get_structure_vec(input_structure_vectors, :prior_samples_in),
get_structure_vec(output_structure_vectors, :prior_samples_out),
)
else
(in_data, out_data)
end
obs_noise_cov = get_structure_mat(output_structure_matrices, :obs_noise_cov)
noise_cov_inv = inv(obs_noise_cov)

li.apply_to = apply_to

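# Estimate the gradient of the map at each input sample: one global linear-regression
# gradient shared by all samples (:linreg), or a locally weighted least-squares fit
# centred at each sample (:localsl).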
grads = if li.grad_type == :linreg
grad = (samples_out .- mean(samples_out; dims = 2)) / (samples_in .- mean(samples_in; dims = 2))
fill(grad, size(samples_in, 2))
else
@assert li.grad_type == :localsl

map(eachcol(samples_in)) do u
# TODO: It might be interesting to introduce a parameter to weight this distance with.
# This can be a scalar or a matrix; in the latter case, we can even use the covariance
# of the samples (or the prior covariance).
weights = exp.(-1/2 * norm.(eachcol(u .- samples_in)).^2)
D = Diagonal(sqrt.(weights))
uw = (samples_in .- mean(samples_in * Diagonal(weights); dims = 2)) * D
gw = (samples_out .- mean(samples_out * Diagonal(weights); dims = 2)) * D
gw / uw
end
end

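# Input space, or output space with α ≈ 0: average the noise-weighted gradient outer
# products over the samples, take the (generalized) eigendecomposition of that diagnostic
# matrix, and keep the leading eigenvectors as the encoder.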
li.encoder_mat = if apply_to == "in" || α ≈ 0
decomp = if apply_to == "in"
eigen(mean(grad' * noise_cov_inv * ((1-α)obs_noise_cov + α^2 * (y - g) * (y - g)') * noise_cov_inv * grad for (g, grad) in zip(eachcol(samples_out), grads)), sortby = (-))
else
@assert apply_to == "out"
eigen(mean(grad * grad' for grad in grads), obs_noise_cov, sortby = (-))
end

if li.dim_criterion[1] == :retain_KL
retain_KL = li.dim_criterion[2]
sv_cumsum = cumsum(decomp.values) / sum(decomp.values)
trunc_val = findfirst(x -> (x ≥ retain_KL), sv_cumsum)
else
@assert li.dim_criterion[1] == :dimension
trunc_val = li.dim_criterion[2]
end
li.encoder_mat = decomp.vectors[:, 1:trunc_val]'
else
@assert apply_to == "out"
@warn "Using LikelihoodInformed on output data with α≠0 triggers a manifold optimization process that may take some time."

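# Output space with α ≠ 0: no closed-form diagnostic, so search for an orthonormal basis
# Vs on the Grassmann manifold that minimizes the residual diagnostic f, doubling the
# subspace dimension k until the :retain_KL criterion is met (or using the requested
# :dimension directly).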
k = if li.dim_criterion[1] == :retain_KL
1
else
@assert li.dim_criterion[1] == :dimension
li.dim_criterion[2]
end
Vs = nothing
while true
M = Grassmann(output_dim, k)

f = (_, Vs) -> begin
prec = noise_cov_inv - Vs * inv(Vs' * obs_noise_cov * Vs) * Vs'
tr(mean(
grad' * prec * ((1-α)I + α^2 * (y - g)*(y - g)') * prec * grad
for (g, grad) in zip(eachcol(out_data), grads)
))
end
egrad = (_, Vs) -> begin
B = Vs * inv(Vs' * obs_noise_cov * Vs) * Vs'
prec = noise_cov_inv - B

-2mean(begin
A = ((1-α)I + α^2 * (y - g)*(y - g)')
S = grad * grad'
(I - obs_noise_cov * B) * (S * prec * A + A * prec * S)
end for (g, grad) in zip(eachcol(out_data), grads)) * B * Vs
end
rgrad = (M, Vs) -> begin
(I - Vs*Vs') * egrad(M, Vs)
end

Vs = Matrix(qr(randn(output_dim, k)).Q)
quasi_Newton!(M, f, rgrad, Vs; stopping_criterion = StopWhenGradientNormLess(3.0))

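# Reference value: the diagnostic with an empty basis (nothing projected out), used to
# measure the fraction captured by the current k-dimensional subspace.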
if li.dim_criterion[1] == :retain_KL
retain_KL = li.dim_criterion[2]
ref = f(M, zeros(output_dim, 0))
if f(M, Vs) / ref ≤ 1 - retain_KL
break # TODO: Start bisecting?
else
k *= 2
end
else
@assert li.dim_criterion[1] == :dimension
break
end
end

Vs'
end
li.decoder_mat = li.encoder_mat'
end
end

"""
$(TYPEDSIGNATURES)

Apply the `LikelihoodInformed` encoder, on a columns-are-data matrix or a data vector
"""
function encode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
encoder_mat = get_encoder_mat(li)
return encoder_mat * data
end

"""
$(TYPEDSIGNATURES)

Apply the `LikelihoodInformed` decoder, on a columns-are-data matrix or a data vector
"""
function decode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
decoder_mat = get_decoder_mat(li)
return decoder_mat * data
end

"""
$(TYPEDSIGNATURES)

Apply the `LikelihoodInformed` encoder to a provided structure matrix
"""
function encode_structure_matrix(
li::LikelihoodInformed,
structure_matrix::SM,
) where {SM <: StructureMatrix}
encoder_mat = get_encoder_mat(li)
return encoder_mat * structure_matrix * encoder_mat'
end

"""
$(TYPEDSIGNATURES)

Apply the `LikelihoodInformed` decoder to a provided structure matrix
"""
function decode_structure_matrix(
li::LikelihoodInformed,
structure_matrix::SM,
) where {SM <: StructureMatrix}
decoder_mat = get_decoder_mat(li)
return decoder_mat * structure_matrix * decoder_mat'
end