diff --git a/Project.toml b/Project.toml
index c7f265dd..aecefaa1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,8 @@ GaussianProcesses = "891a1506-143c-57d2-908e-e1f8e92e6de9"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
+Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
@@ -42,6 +44,8 @@ ForwardDiff = "0.10.38, 1"
 GaussianProcesses = "0.12"
 KernelFunctions = "0.10.64"
 MCMCChains = "4.14, 5, 6, 7"
+Manifolds = "0.10.23"
+Manopt = "0.5.20"
 Printf = "1"
 ProgressBars = "1"
 PyCall = "1.93"
diff --git a/src/Utilities.jl b/src/Utilities.jl
index 71585d3e..d87a61dc 100644
--- a/src/Utilities.jl
+++ b/src/Utilities.jl
@@ -1,7 +1,5 @@
 module Utilities
 
-
-
 using DocStringExtensions
 using LinearAlgebra
 using Statistics
@@ -477,5 +475,6 @@ end
 include("Utilities/canonical_correlation.jl")
 include("Utilities/decorrelator.jl")
 include("Utilities/elementwise_scaler.jl")
+include("Utilities/likelihood_informed.jl")
 
 end # module
diff --git a/src/Utilities/canonical_correlation.jl b/src/Utilities/canonical_correlation.jl
index c7ab6a47..eb6304cd 100644
--- a/src/Utilities/canonical_correlation.jl
+++ b/src/Utilities/canonical_correlation.jl
@@ -181,9 +181,9 @@ initialize_processor!(
 """
 $(TYPEDSIGNATURES)
 
-Apply the `CanonicalCorrelation` encoder, on a columns-are-data matrix
+Apply the `CanonicalCorrelation` encoder, on a columns-are-data matrix or a data vector
 """
-function encode_data(cc::CanonicalCorrelation, data::MM) where {MM <: AbstractMatrix}
+function encode_data(cc::CanonicalCorrelation, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
     data_mean = get_data_mean(cc)[1]
     encoder_mat = get_encoder_mat(cc)[1]
     return encoder_mat * (data .- data_mean)
@@ -192,9 +192,9 @@ end
 """
 $(TYPEDSIGNATURES)
 
-Apply the `CanonicalCorrelation` decoder, on a columns-are-data matrix
+Apply the `CanonicalCorrelation` decoder, on a columns-are-data matrix or a data vector
 """
-function decode_data(cc::CanonicalCorrelation, data::MM) where {MM <: AbstractMatrix}
+function decode_data(cc::CanonicalCorrelation, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
     data_mean = get_data_mean(cc)[1]
     decoder_mat = get_decoder_mat(cc)[1]
     return decoder_mat * data .+ data_mean
diff --git a/src/Utilities/decorrelator.jl b/src/Utilities/decorrelator.jl
index 9f237a82..9fb60190 100644
--- a/src/Utilities/decorrelator.jl
+++ b/src/Utilities/decorrelator.jl
@@ -187,9 +187,9 @@ end
 """
 $(TYPEDSIGNATURES)
 
-Apply the `Decorrelator` encoder, on a columns-are-data matrix
+Apply the `Decorrelator` encoder, on a columns-are-data matrix or a data vector
 """
-function encode_data(dd::Decorrelator, data::MM) where {MM <: AbstractMatrix}
+function encode_data(dd::Decorrelator, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
     data_mean = get_data_mean(dd)[1]
     encoder_mat = get_encoder_mat(dd)[1]
     return encoder_mat * (data .- data_mean)
@@ -198,9 +198,9 @@ end
 """
 $(TYPEDSIGNATURES)
 
-Apply the `Decorrelator` decoder, on a columns-are-data matrix
+Apply the `Decorrelator` decoder, on a columns-are-data matrix or a data vector
 """
-function decode_data(dd::Decorrelator, data::MM) where {MM <: AbstractMatrix}
+function decode_data(dd::Decorrelator, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
     data_mean = get_data_mean(dd)[1]
     decoder_mat = get_decoder_mat(dd)[1]
     return decoder_mat * data .+ data_mean
diff --git a/src/Utilities/elementwise_scaler.jl b/src/Utilities/elementwise_scaler.jl
index 6cca8a9d..c718aec8 100644
--- a/src/Utilities/elementwise_scaler.jl
+++ b/src/Utilities/elementwise_scaler.jl
@@ -126,9 +126,9 @@ end
 """
 $(TYPEDSIGNATURES)
 
-Apply the `ElementwiseScaler` encoder, on a columns-are-data matrix
+Apply the `ElementwiseScaler` encoder, on a columns-are-data matrix or a data vector
 """
-function encode_data(es::ElementwiseScaler, data::MM) where {MM <: AbstractMatrix}
+function encode_data(es::ElementwiseScaler, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
     out = deepcopy(data)
     for i in 1:size(out, 1)
         out[i, :] .-= get_shift(es)[i]
@@ -140,9 +140,9 @@ end
 """
 $(TYPEDSIGNATURES)
 
-Apply the `ElementwiseScaler` decoder, on a columns-are-data matrix
+Apply the `ElementwiseScaler` decoder, on a columns-are-data matrix or a data vector
 """
-function decode_data(es::ElementwiseScaler, data::MM) where {MM <: AbstractMatrix}
+function decode_data(es::ElementwiseScaler, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
     out = deepcopy(data)
     for i in 1:size(out, 1)
         out[i, :] *= get_scale(es)[i]
diff --git a/src/Utilities/likelihood_informed.jl b/src/Utilities/likelihood_informed.jl
new file mode 100644
index 00000000..550c3bb6
--- /dev/null
+++ b/src/Utilities/likelihood_informed.jl
@@ -0,0 +1,250 @@
+# included in Utilities.jl
+
+using Manifolds, Manopt
+
+export LikelihoodInformed, likelihood_informed
+
+"""
+$(TYPEDEF)
+
+Linear encoder/decoder pair. The encoder matrix is computed by `initialize_processor!` from
+gradient estimates of the input-to-output map and the `:obs_noise_cov` structure matrix; the
+decoder matrix is its adjoint.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+mutable struct LikelihoodInformed{FT <: Real} <: PairedDataContainerProcessor
+    "Encoder matrix; `nothing` until `initialize_processor!` has computed it"
+    encoder_mat::Union{Nothing, AbstractMatrix}
+    "Decoder matrix; set to the adjoint of `encoder_mat` by `initialize_processor!`"
+    decoder_mat::Union{Nothing, AbstractMatrix}
+    "The `apply_to` string passed to `initialize_processor!`; `nothing` before initialization"
+    apply_to::Union{Nothing, AbstractString}
+    "Truncation rule, either `(:retain_KL, fraction)` or `(:dimension, k)`"
+    dim_criterion::Tuple{Symbol, <:Number}
+    "Scalar weight; when `α ≈ 0` the `:observation` structure vector is not required"
+    α::FT
+    "Gradient estimator, `:linreg` or `:localsl`"
+    grad_type::Symbol
+    "If `true`, use the `:prior_samples_in`/`:prior_samples_out` structure vectors instead of the data"
+    use_prior_samples::Bool
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Construct an uninitialized `LikelihoodInformed` processor with dimension criterion
+`(:retain_KL, retain_KL)`.
+
+# Keywords
+- `alpha`: stored in the field `α`.
+- `grad_type`: gradient estimator, `:linreg` or `:localsl`; anything else throws an `ArgumentError`.
+- `use_prior_samples`: stored in the field of the same name; `initialize_processor!` asserts `alpha ≈ 0` when it is `true`.
+"""
+function likelihood_informed(retain_KL; alpha = 0.0, grad_type = :localsl, use_prior_samples = true)
+    if grad_type ∉ (:linreg, :localsl)
+        # `@error` only logs a message; throw so that an unusable processor is never returned
+        throw(ArgumentError("Unknown grad_type=$grad_type"))
+    end
+
+    return LikelihoodInformed(nothing, nothing, nothing, (:retain_KL, retain_KL), alpha, grad_type, use_prior_samples)
+end
+
+get_encoder_mat(li::LikelihoodInformed) = li.encoder_mat
+get_decoder_mat(li::LikelihoodInformed) = li.decoder_mat
+
+"""
+$(TYPEDSIGNATURES)
+
+Compute and store the encoder and decoder matrices of `li`; does nothing if an encoder matrix
+is already present.
+
+Reads `:obs_noise_cov` from `output_structure_matrices`, `:observation` from
+`output_structure_vectors` unless `α ≈ 0`, and `:prior_samples_in`/`:prior_samples_out` from
+the structure vectors when `use_prior_samples` is set. For `apply_to == "in"`, or when
+`α ≈ 0`, the encoder comes from an eigendecomposition; otherwise (`apply_to == "out"`) it is
+found by a quasi-Newton optimization on a Grassmann manifold.
+"""
+function initialize_processor!(
+    li::LikelihoodInformed,
+    in_data::MM,
+    out_data::MM,
+    ::Dict{Symbol, <:StructureMatrix},
+    output_structure_matrices::Dict{Symbol, <:StructureMatrix},
+    input_structure_vectors::Dict{Symbol, <:StructureVector},
+    output_structure_vectors::Dict{Symbol, <:StructureVector},
+    apply_to::AbstractString,
+) where {MM <: AbstractMatrix}
+    output_dim = size(out_data, 1)
+
+    if isnothing(get_encoder_mat(li))
+        α = li.α
+        y = if α ≈ 0.0
+            # For α=0, it doesn't matter what this value is, so we avoid requiring its presence
+            zeros(output_dim)
+        else
+            get_structure_vec(output_structure_vectors, :observation)
+        end
+        samples_in, samples_out = if li.use_prior_samples
+            @assert α ≈ 0.0
+            (
+                get_structure_vec(input_structure_vectors, :prior_samples_in),
+                get_structure_vec(output_structure_vectors, :prior_samples_out),
+            )
+        else
+            (in_data, out_data)
+        end
+        obs_noise_cov = get_structure_mat(output_structure_matrices, :obs_noise_cov)
+        noise_cov_inv = inv(obs_noise_cov)
+
+        li.apply_to = apply_to
+
+        grads = if li.grad_type == :linreg
+            # A single least-squares slope of centered outputs on centered inputs, shared by all samples
+            grad = (samples_out .- mean(samples_out; dims = 2)) / (samples_in .- mean(samples_in; dims = 2))
+            fill(grad, size(samples_in, 2))
+        else
+            @assert li.grad_type == :localsl
+
+            map(eachcol(samples_in)) do u
+                # TODO: It might be interesting to introduce a parameter to weight this distance with.
+                # This can be a scalar or a matrix; in the latter case, we can even use the covariance
+                # of the samples (or the prior covariance).
+                # NOTE(review): `mean(X * Diagonal(weights); dims = 2)` divides by the number of samples,
+                # not by `sum(weights)`; confirm whether a normalized weighted mean is intended.
+                weights = exp.(-1/2 * norm.(eachcol(u .- samples_in)).^2)
+                D = Diagonal(sqrt.(weights))
+                uw = (samples_in .- mean(samples_in * Diagonal(weights); dims = 2)) * D
+                gw = (samples_out .- mean(samples_out * Diagonal(weights); dims = 2)) * D
+                gw / uw
+            end
+        end
+
+        li.encoder_mat = if apply_to == "in" || α ≈ 0
+            decomp = if apply_to == "in"
+                eigen(mean(grad' * noise_cov_inv * ((1-α)obs_noise_cov + α^2 * (y - g) * (y - g)') * noise_cov_inv * grad for (g, grad) in zip(eachcol(samples_out), grads)), sortby = (-))
+            else
+                @assert apply_to == "out"
+                eigen(mean(grad * grad' for grad in grads), obs_noise_cov, sortby = (-))
+            end
+
+            trunc_val = if li.dim_criterion[1] == :retain_KL
+                retain_KL = li.dim_criterion[2]
+                sv_cumsum = cumsum(decomp.values) / sum(decomp.values)
+                # If no cumulative fraction reaches `retain_KL` (e.g. through rounding), keep every
+                # direction instead of failing on `1:nothing`
+                something(findfirst(x -> (x ≥ retain_KL), sv_cumsum), length(sv_cumsum))
+            else
+                @assert li.dim_criterion[1] == :dimension
+                li.dim_criterion[2]
+            end
+            decomp.vectors[:, 1:trunc_val]'
+        else
+            @assert apply_to == "out"
+            @warn "Using LikelihoodInformed on output data with α≠0 triggers a manifold optimization process that may take some time."
+
+            k = if li.dim_criterion[1] == :retain_KL
+                1
+            else
+                @assert li.dim_criterion[1] == :dimension
+                li.dim_criterion[2]
+            end
+            Vs = nothing
+            while true
+                M = Grassmann(output_dim, k)
+
+                # NOTE(review): the "in" branch above uses `(1-α)obs_noise_cov` where `f` and `egrad`
+                # use `(1-α)I`; confirm that this difference is intended.
+                f = (_, Vs) -> begin
+                    prec = noise_cov_inv - Vs * inv(Vs' * obs_noise_cov * Vs) * Vs'
+                    tr(mean(
+                        grad' * prec * ((1-α)I + α^2 * (y - g)*(y - g)') * prec * grad
+                        for (g, grad) in zip(eachcol(out_data), grads)
+                    ))
+                end
+                egrad = (_, Vs) -> begin
+                    B = Vs * inv(Vs' * obs_noise_cov * Vs) * Vs'
+                    prec = noise_cov_inv - B
+
+                    -2mean(begin
+                        A = ((1-α)I + α^2 * (y - g)*(y - g)')
+                        S = grad * grad'
+                        (I - obs_noise_cov * B) * (S * prec * A + A * prec * S)
+                    end for (g, grad) in zip(eachcol(out_data), grads)) * B * Vs
+                end
+                rgrad = (M, Vs) -> begin
+                    (I - Vs*Vs') * egrad(M, Vs)
+                end
+
+                # `Matrix(qr(A))` multiplies the factors back together and returns `A`; take the thin
+                # Q factor so the starting point has orthonormal columns, as `rgrad` assumes
+                Vs = Matrix(qr(randn(output_dim, k)).Q)
+                quasi_Newton!(M, f, rgrad, Vs; stopping_criterion = StopWhenGradientNormLess(3.0))
+
+                if li.dim_criterion[1] == :retain_KL
+                    retain_KL = li.dim_criterion[2]
+                    ref = f(M, zeros(output_dim, 0))
+                    # Also stop at `k == output_dim`, so that `k` never exceeds the ambient dimension
+                    if f(M, Vs) / ref ≤ 1 - retain_KL || k ≥ output_dim
+                        break # TODO: Start bisecting?
+                    else
+                        k = min(2k, output_dim)
+                    end
+                else
+                    @assert li.dim_criterion[1] == :dimension
+                    break
+                end
+            end
+
+            Vs'
+        end
+        li.decoder_mat = li.encoder_mat'
+    end
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Apply the `LikelihoodInformed` encoder, on a columns-are-data matrix or a data vector
+"""
+function encode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
+    encoder_mat = get_encoder_mat(li)
+    return encoder_mat * data
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Apply the `LikelihoodInformed` decoder, on a columns-are-data matrix or a data vector
+"""
+function decode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
+    decoder_mat = get_decoder_mat(li)
+    return decoder_mat * data
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Apply the `LikelihoodInformed` encoder to a provided structure matrix
+"""
+function encode_structure_matrix(
+    li::LikelihoodInformed,
+    structure_matrix::SM,
+) where {SM <: StructureMatrix}
+    encoder_mat = get_encoder_mat(li)
+    return encoder_mat * structure_matrix * encoder_mat'
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Apply the `LikelihoodInformed` decoder to a provided structure matrix
+"""
+function decode_structure_matrix(
+    li::LikelihoodInformed,
+    structure_matrix::SM,
+) where {SM <: StructureMatrix}
+    decoder_mat = get_decoder_mat(li)
+    return decoder_mat * structure_matrix * decoder_mat'
+end