diff --git a/examples/DimensionReduction/Project.toml b/examples/DimensionReduction/Project.toml new file mode 100644 index 000000000..0d266b679 --- /dev/null +++ b/examples/DimensionReduction/Project.toml @@ -0,0 +1,16 @@ +[deps] +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" +Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/examples/DimensionReduction/datafiles/.gitkeep b/examples/DimensionReduction/datafiles/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/examples/DimensionReduction/figures/.gitkeep b/examples/DimensionReduction/figures/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/examples/DimensionReduction/problems/forward_maps.jl b/examples/DimensionReduction/problems/forward_maps.jl new file mode 100644 index 000000000..79c93504c --- /dev/null +++ b/examples/DimensionReduction/problems/forward_maps.jl @@ -0,0 +1 @@ +abstract type ForwardMapType end diff --git a/examples/DimensionReduction/problems/problem_linear.jl b/examples/DimensionReduction/problems/problem_linear.jl new file mode 100644 index 000000000..717e1ec3f --- /dev/null +++ b/examples/DimensionReduction/problems/problem_linear.jl @@ -0,0 +1,52 @@ +using LinearAlgebra +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +using Statistics +using Distributions + +include("forward_maps.jl") + +function linear(input_dim, output_dim, rng) + # prior + γ0 = 4.0 + β_γ = -2 + Γ = Diagonal([γ0 * (1.0 * j)^β_γ for j in 1:input_dim]) + prior_dist = MvNormal(zeros(input_dim), Γ) + prior = ParameterDistribution( + Dict( + "distribution" => Parameterized(prior_dist), + "constraint" => repeat([no_constraint()], input_dim), + "name" => "param_$(input_dim)", + ), + ) + + U = qr(randn(rng, (output_dim, output_dim))).Q + V = qr(randn(rng, (input_dim, input_dim))).Q + λ0 = 100.0 + β_λ = -1 + Λ = Diagonal([λ0 * (1.0 * j)^β_λ for j in 1:output_dim]) + A = U * Λ * V[1:output_dim, :] # output x input + model = Linear(input_dim, output_dim, A) + + # generate data sample + obs_noise_cov = Diagonal([Float64(j)^(-1 / 2) for j in 1:output_dim]) + noise = rand(rng, MvNormal(zeros(output_dim), obs_noise_cov)) + # true_parameter = reshape(ones(input_dim), :, 1) + true_parameter = rand(prior_dist) + y = vec(forward_map(true_parameter, model) + noise) + return prior, y, obs_noise_cov, model, true_parameter +end + +struct Linear{AM <: AbstractMatrix} <: ForwardMapType + input_dim::Int + output_dim::Int + G::AM +end + +function forward_map(X::AVorM, model::Linear) where {AVorM <: AbstractVecOrMat} + return model.G * X +end + +function jac_forward_map(X::AbstractMatrix, model::Linear) + return [model.G for _ in eachcol(X)] +end diff --git a/examples/DimensionReduction/problems/problem_linear_exp.jl b/examples/DimensionReduction/problems/problem_linear_exp.jl new file mode 
100644 index 000000000..f8d34dab3 --- /dev/null +++ b/examples/DimensionReduction/problems/problem_linear_exp.jl @@ -0,0 +1,62 @@ +using LinearAlgebra +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +using Statistics +using Distributions + +# Inverse problem will be taken from (Cui, Tong, 2021) https://arxiv.org/pdf/2101.02417, example 7.1 +include("forward_maps.jl") + +function linear_exp(input_dim, output_dim, rng) + # prior + γ0 = 4.0 + β_γ = -2 + Γ = Diagonal([γ0 * (1.0 * j)^β_γ for j in 1:input_dim]) + prior_dist = MvNormal(zeros(input_dim), Γ) + prior = ParameterDistribution( + Dict( + "distribution" => Parameterized(prior_dist), + "constraint" => repeat([no_constraint()], input_dim), + "name" => "param_$(input_dim)", + ), + ) + + # forward map + # random linear-exp forward map from Stewart 1980: https://www.jstor.org/stable/2156882?seq=2 + U = qr(randn(rng, (output_dim, output_dim))).Q + V = qr(randn(rng, (input_dim, input_dim))).Q + λ0 = 100.0 + β_λ = -1 + Λ = Diagonal([λ0 * (1.0 * j)^β_λ for j in 1:output_dim]) + A = U * Λ * V[1:output_dim, :] # output x input + model = LinearExp(input_dim, output_dim, A) + + # generate data sample + obs_noise_std = 1.0 + obs_noise_cov = (obs_noise_std^2) * I(output_dim) + noise = rand(rng, MvNormal(zeros(output_dim), obs_noise_cov)) + # true_parameter = reshape(ones(input_dim), :, 1) + true_parameter = rand(prior_dist) + y = vec(forward_map(true_parameter, model) + noise) + return prior, y, obs_noise_cov, model, true_parameter +end + + +## G*exp(X) +struct LinearExp{AM <: AbstractMatrix} <: ForwardMapType + input_dim::Int + output_dim::Int + G::AM +end + +# columns of X are samples +function forward_map(X::AVorM, model::LE) where {LE <: LinearExp, AVorM <: AbstractVecOrMat} + return model.G * exp.(X) +end + +# columns of X are samples +function jac_forward_map(X::AM, model::LE) where {AM <: AbstractMatrix, LE <: LinearExp} + # dGi / dXj = G_ij exp(x_j) = G.*exp.(mat with repeated x_j rows) + # return [G * exp.(Diagonal(r)) for r in eachrow(X')] # correct but extra multiplies + return [model.G .* exp.(reshape(c, 1, :)) for c in eachcol(X)] +end diff --git a/examples/DimensionReduction/problems/problem_linlinexp.jl b/examples/DimensionReduction/problems/problem_linlinexp.jl new file mode 100644 index 000000000..4aaf366b5 --- /dev/null +++ b/examples/DimensionReduction/problems/problem_linlinexp.jl @@ -0,0 +1,56 @@ +using LinearAlgebra +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +using Statistics +using Distributions + +include("forward_maps.jl") + +function linlinexp(input_dim, output_dim, rng) + # prior + γ0 = 4.0 + β_γ = -2 + Γ = Diagonal([γ0 * (1.0 * j)^β_γ for j in 1:input_dim]) + prior_dist = MvNormal(zeros(input_dim), Γ) + prior = ParameterDistribution( + Dict( + "distribution" => Parameterized(prior_dist), + "constraint" => repeat([no_constraint()], input_dim), + "name" => "param_$(input_dim)", + ), + ) + + U = qr(randn(rng, (output_dim, output_dim))).Q + V = qr(randn(rng, (input_dim, input_dim))).Q + λ0 = 100.0 + β_λ = -1 + Λ = Diagonal([λ0 * (1.0 * j)^β_λ for j in 1:output_dim]) + A = U * Λ * V[1:output_dim, :] # output x input + model = LinLinExp(input_dim, output_dim, A) + + # generate data sample + obs_noise_cov = Diagonal([Float64(j)^(-1 / 2) for j in 1:output_dim]) + noise = rand(rng, MvNormal(zeros(output_dim), obs_noise_cov)) + # true_parameter = reshape(ones(input_dim), :, 1) + true_parameter = rand(prior_dist) + y = vec(forward_map(true_parameter, model) 
+ noise) + return prior, y, obs_noise_cov, model, true_parameter +end + +struct LinLinExp{AM <: AbstractMatrix} <: ForwardMapType + input_dim::Int + output_dim::Int + G::AM +end + +function forward_map(X::AVorM, model::LinLinExp) where {AVorM <: AbstractVecOrMat} + return model.G * (X .* exp.(0.05X)) +end + +function jac_forward_map(X::AbstractVector, model::LinLinExp) + return model.G * Diagonal(exp.(0.05X) .* (1 .+ 0.05X)) +end + +function jac_forward_map(X::AbstractMatrix, model::LinLinExp) + return [jac_forward_map(x, model) for x in eachcol(X)] +end diff --git a/examples/DimensionReduction/problems/problem_lorenz.jl b/examples/DimensionReduction/problems/problem_lorenz.jl new file mode 100644 index 000000000..0a5972585 --- /dev/null +++ b/examples/DimensionReduction/problems/problem_lorenz.jl @@ -0,0 +1,277 @@ +include("../../Lorenz/GModel.jl") # Contains Lorenz 96 source code + +include("./forward_maps.jl") + +# Import modules +using Distributions # probability distributions and associated functions +using LinearAlgebra +using StatsPlots +using Plots +using Random +using JLD2 +using Statistics + +# CES +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +using EnsembleKalmanProcesses.Localizers + +const EKP = EnsembleKalmanProcesses + +# G(θ) = H(Ψ(θ,x₀,t₀,t₁)) +# y = G(θ) + η + +# This will change for different Lorenz simulators +struct LorenzConfig{FT1 <: Real, FT2 <: Real} + "Length of a fixed integration timestep" + dt::FT1 + "Total duration of integration (T = N*dt)" + T::FT2 +end + +# This will change for each ensemble member +struct EnsembleMemberConfig{VV <: AbstractVector} + "state-dependent-forcing" + F::VV +end + +# This will change for different "Observations" of Lorenz +struct ObservationConfig{FT1 <: Real, FT2 <: Real} + "initial time to gather statistics (T_start = N_start*dt)" + T_start::FT1 + "end time to gather statistics (T_end = N_end*dt)" + T_end::FT2 +end +######################################################################### +############################ Model Functions ############################ +######################################################################### + +# Forward pass of forward model +# Inputs: +# - params: structure with F (state-dependent-forcing vector) +# - x0: initial condition vector +# - config: structure including dt (timestep Float64(1)) and T (total time Float64(1)) +function lorenz_forward( + params::EnsembleMemberConfig, + x0::VorM, + config::LorenzConfig, + observation_config::ObservationConfig, +) where {VorM <: AbstractVecOrMat} + # run the Lorenz simulation + xn = lorenz_solve(params, x0, config) + # Get statistics + gt = stats(xn, config, observation_config) + return gt +end + +#Calculates statistics for forward model output +# Inputs: +# - xn: timeseries of states for length of simulation through Lorenz96 +function stats(xn::VorM, config::LorenzConfig, observation_config::ObservationConfig) where {VorM <: AbstractVecOrMat} + T_start = observation_config.T_start + T_end = observation_config.T_end + dt = config.dt + N_start = Int(ceil(T_start / dt)) + N_end = Int(ceil(T_end / dt)) + xn_stat = xn[:, N_start:N_end] + N_state = size(xn_stat, 1) + gt = zeros(2 * N_state) + gt[1:N_state] = mean(xn_stat, dims = 2) + gt[(N_state + 1):(2 * N_state)] = std(xn_stat, dims = 2) + return gt +end + +# Forward pass of the Lorenz 96 model +# Inputs: +# - params: structure with F (state-dependent-forcing vector) +# - x0: initial condition vector +# - config: structure including dt (timestep Float64(1)) 
and T (total time Float64(1)) +function lorenz_solve(params::EnsembleMemberConfig, x0::VorM, config::LorenzConfig) where {VorM <: AbstractVecOrMat} + # Initialize + nstep = Int(ceil(config.T / config.dt)) + state_dim = isa(x0, AbstractVector) ? length(x0) : size(x0, 1) + xn = zeros(size(x0, 1), nstep + 1) + xn[:, 1] = x0 + + # March forward in time + for j in 1:nstep + xn[:, j + 1] = RK4(params, xn[:, j], config) + end + # Output + return xn +end + +# Lorenz 96 system +# f = dx/dt +# Inputs: +# - params: structure with F (state-dependent-forcing vector) +# - x: current state +function f(params::EnsembleMemberConfig, x::VorM) where {VorM <: AbstractVecOrMat} + F = params.F + N = length(x) + f = zeros(N) + # Loop over N positions + for i in 3:(N - 1) + f[i] = -x[i - 2] * x[i - 1] + x[i - 1] * x[i + 1] - x[i] + F[i] + end + # Periodic boundary conditions + f[1] = -x[N - 1] * x[N] + x[N] * x[2] - x[1] + F[1] + f[2] = -x[N] * x[1] + x[1] * x[3] - x[2] + F[2] + f[N] = -x[N - 2] * x[N - 1] + x[N - 1] * x[1] - x[N] + F[N] + # Output + return f +end + +# RK4 solve +# Inputs: +# - params: structure with F (state-dependent-forcing vector) +# - xold: current state +# - config: structure including dt (timestep Float64(1)) and T (total time Float64(1)) +function RK4(params::EnsembleMemberConfig, xold::VorM, config::LorenzConfig) where {VorM <: AbstractVecOrMat} + N = length(xold) + dt = config.dt + + # Predictor steps (note no time-dependence is needed here) + k1 = f(params, xold) + k2 = f(params, xold + k1 * dt / 2.0) + k3 = f(params, xold + k2 * dt / 2.0) + k4 = f(params, xold + k3 * dt) + # Step + xnew = xold + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4) + # Output + return xnew +end + + +######################################################################## +############################ Problem setup ############################# +######################################################################## + +struct Lorenz <: ForwardMapType + rng::Any + config_settings::Any + observation_config::Any + x0::Any + ic_cov_sqrt::Any + nx::Any +end + +# columns of X are samples +function forward_map(X::AbstractVector, model::Lorenz; noise = nothing) + noise = isnothing(noise) ? model.ic_cov_sqrt * randn(model.rng, model.nx) : noise + lorenz_forward( + EnsembleMemberConfig(X .+ 8.0), + (model.x0 .+ noise), + model.config_settings, + model.observation_config, + ) +end + +function forward_map(X::AbstractMatrix, model::Lorenz; noise = nothing) + hcat([forward_map(x, model; noise) for x in eachcol(X)]...) 
+end
+
+function jac_forward_map(X::AbstractVector, model::Lorenz)
+    # Finite-difference Jacobian
+    nx = model.nx
+    noise = model.ic_cov_sqrt * randn(model.rng, model.nx)
+    h = 1e-6
+    J = zeros(nx * 2, nx)
+    for i in 1:nx
+        x_plus_h = copy(X)
+        x_plus_h[i] += h
+        x_minus_h = copy(X)
+        x_minus_h[i] -= h
+        J[:, i] = (forward_map(x_plus_h, model; noise) - forward_map(x_minus_h, model; noise)) / (2 * h)
+    end
+    return J
+end
+
+function jac_forward_map(X::AbstractMatrix, model::Lorenz)
+    return [jac_forward_map(x, model) for x in eachcol(X)]
+end
+
+function lorenz(input_dim, output_dim, rng)
+    # Creating synthetic data
+    # Initialize model variables
+    nx = 40 # dimension of parameter vector
+    ny = nx * 2 # number of data points
+    @assert input_dim == nx
+    @assert output_dim == ny
+
+    gamma = 8 .+ 6 * sin.((4 * pi * range(0, stop = nx - 1, step = 1)) / nx) # forcing (needs to be of type EnsembleMemberConfig)
+    true_parameters = EnsembleMemberConfig(gamma)
+
+    t = 0.01 # time step
+    T_long = 1000.0 # total time
+    picking_initial_condition = LorenzConfig(t, T_long)
+
+    # Beginning state
+    x_initial = rand(rng, Normal(0.0, 1.0), nx)
+
+    # Find the initial condition for the data
+    x_spun_up = lorenz_solve(true_parameters, x_initial, picking_initial_condition) # needs a LorenzConfig object with t, T_long
+
+    # Initial condition used for the data
+    x0 = x_spun_up[:, end] # last element of the spin-up run is the initial condition for creating the data
+
+    # Creating synthetic data
+    T = 14.0
+    lorenz_config_settings = LorenzConfig(t, T)
+
+    # construct how we compute Observations
+    T_start = 4.0 # 2*max
+    T_end = T
+    observation_config = ObservationConfig(T_start, T_end)
+
+    model_out_y = lorenz_forward(true_parameters, x0, lorenz_config_settings, observation_config)
+
+    # Observation covariance
+    # [Don't need to do this bit really] - initial condition perturbations
+    covT = 1000.0 # time to simulate to calculate a covariance matrix of the system
+    cov_solve = lorenz_solve(true_parameters, x0, LorenzConfig(t, covT))
+    ic_cov = 0.1 * cov(cov_solve, dims = 2)
+    ic_cov_sqrt = sqrt(ic_cov)
+
+    n_samples = 200
+    y_ens = hcat(
+        [
+            lorenz_forward(
+                true_parameters,
+                (x0 .+ ic_cov_sqrt * rand(rng, Normal(0.0, 1.0), nx, n_samples))[:, j],
+                lorenz_config_settings,
+                observation_config,
+            ) for j in 1:n_samples
+        ]...,
+    )
+
+    # estimate noise from IC-effect + R
+    obs_noise_cov = cov(y_ens, dims = 2)
+    y_mean = mean(y_ens, dims = 2)
+    y = y_ens[:, 1]
+
+    pl = 2.0
+    psig = 3.0
+    # Prior covariance
+    B = zeros(nx, nx)
+    for ii in 1:nx
+        for jj in 1:nx
+            B[ii, jj] = psig^2 * exp(-abs(ii - jj) / pl)
+        end
+    end
+    B_sqrt = sqrt(B)
+
+    # Prior mean
+    mu = zeros(nx)
+
+    # Creating prior distribution
+    distribution = Parameterized(MvNormal(mu, B))
+    constraint = repeat([no_constraint()], nx)
+    name = "ml96_prior"
+
+    prior = ParameterDistribution(distribution, constraint, name)
+
+    model = Lorenz(rng, lorenz_config_settings, observation_config, x0, ic_cov_sqrt, nx)
+
+    return prior, y, obs_noise_cov, model, gamma
+end
diff --git a/examples/DimensionReduction/settings.jl b/examples/DimensionReduction/settings.jl
new file mode 100644
index 000000000..cc6704e8e
--- /dev/null
+++ b/examples/DimensionReduction/settings.jl
@@ -0,0 +1,44 @@
+# CONFIGURE THE THREE STEPS
+
+## -- Configure the inverse problem --
+problem = "linlinexp" # "lorenz" or "linear" or "linear_exp" or "linlinexp"
+input_dim = 200
+output_dim = 50
+
+## -- Configure parameters of the experiment itself --
+rng_seed = 41
+num_trials = 1
+αs = 0.0:0.25:1.0
+grad_types =
(:perfect,) # Out of :perfect, :mean, :linreg, and :localsl +Vgrad_types = () # Out of :egi + +# Specific to step 1 +step1_eki_ensemble_size = 200 +step1_mcmc_sampler = :rw # :rw or :mala +step1_mcmc_samples_per_chain = 50_000 +step1_mcmc_num_chains = 8 +step1_mcmc_subsample_rate = 1000 + +# Specific to step 2 +step2_manopt_num_dims = 8 +step2_Vgrad_num_samples = 8 +step2_egi_ξ = 0.0 +step2_egi_γ = 1.5 + +# Specific to step 3 +step3_diagnostics_to_use = [ + ("Hu_1.0_mcmc_perfect", input_dim, "Hg_0.0_mcmc_perfect", step2_manopt_num_dims), + ("Hu_1.0_mcmc_perfect", input_dim, "Hg_0.25_mcmc_perfect", step2_manopt_num_dims), + ("Hu_1.0_mcmc_perfect", input_dim, "Hg_0.5_mcmc_perfect", step2_manopt_num_dims), + ("Hu_1.0_mcmc_perfect", input_dim, "Hg_0.75_mcmc_perfect", step2_manopt_num_dims), + ("Hu_1.0_mcmc_perfect", input_dim, "Hg_1.0_mcmc_perfect", step2_manopt_num_dims), +] +step3_run_reduced_in_full_space = false +step3_marginalization = :forward_model # :loglikelihood or :forward_model +step3_num_marginalization_samples = 1 +step3_posterior_sampler = :mcmc # :eks or :mcmc +step3_eks_ensemble_size = 800 # only used if `step3_posterior_sampler == :eks` +step3_eks_max_iters = 200 # only used if `step3_posterior_sampler == :eks` +step3_mcmc_sampler = :rw # :rw or :mala; only used if `step3_posterior_sampler == :mcmc` +step3_mcmc_samples_per_chain = 20_000 # only used if `step3_posterior_sampler == :mcmc` +step3_mcmc_num_chains = 24 # only used if `step3_posterior_sampler == :mcmc` diff --git a/examples/DimensionReduction/step1_generate_inverse_problem_data.jl b/examples/DimensionReduction/step1_generate_inverse_problem_data.jl new file mode 100644 index 000000000..39d8ad5af --- /dev/null +++ b/examples/DimensionReduction/step1_generate_inverse_problem_data.jl @@ -0,0 +1,129 @@ +using AdvancedMH +using Distributions +using EnsembleKalmanProcesses +using ForwardDiff +using JLD2 +using LinearAlgebra +using MCMCChains +using Random + +include("./problems/problem_linear.jl") +include("./problems/problem_linear_exp.jl") +include("./problems/problem_lorenz.jl") +include("./problems/problem_linlinexp.jl") + +include("./settings.jl") +include("./util.jl") +rng = Random.MersenneTwister(rng_seed) +problem_fun = if problem == "linear" + linear +elseif problem == "linear_exp" + linear_exp +elseif problem == "lorenz" + lorenz +elseif problem == "linlinexp" + linlinexp +else + throw("Unknown problem=$problem") +end + +mutable struct CheckpointScheduler <: EnsembleKalmanProcesses.LearningRateScheduler + αs::Vector{Float64} + scheduler + + current_index + + CheckpointScheduler(αs, scheduler) = new(αs, scheduler, 2) +end + +function EnsembleKalmanProcesses.calculate_timestep!(ekp, g, Δt_new, scheduler::CheckpointScheduler) + EnsembleKalmanProcesses.calculate_timestep!(ekp, g, Δt_new, scheduler.scheduler) + if scheduler.current_index <= length(scheduler.αs) && get_algorithm_time(ekp)[end] > scheduler.αs[scheduler.current_index] + get_algorithm_time(ekp)[end] = scheduler.αs[scheduler.current_index] + scheduler.current_index += 1 + end + + nothing +end + + + +for trial in 1:num_trials + @info "Trial $trial" + + prior, y, obs_noise_cov, model, true_parameter = problem_fun(input_dim, output_dim, rng) + + # [1] EKP run + n_ensemble = step1_eki_ensemble_size + + initial_ensemble = construct_initial_ensemble(rng, prior, n_ensemble) + ekp = EnsembleKalmanProcess( + initial_ensemble, + y, + obs_noise_cov, + TransformInversion(); + rng, + scheduler = CheckpointScheduler(αs, EKSStableScheduler(2.0, 0.01)), + ) + + n_iters 
= 0 + while vcat([0.0], get_algorithm_time(ekp))[end] < maximum(αs) + n_iters += 1 + G_ens = hcat([forward_map(param, model) for param in eachcol(get_ϕ_final(prior, ekp))]...) + terminate = update_ensemble!(ekp, G_ens) + if !isnothing(terminate) + throw("EKI terminated prematurely: $(terminate)! Shouldn't happen...") + end + end + @info "EKP iterations: $n_iters" + @info "Loss over iterations: $(get_error(ekp))" + + ekp_samples = Dict() + for α in αs + closest_iter = argmin(0:n_iters) do i + abs(α - (i == 0 ? 0.0 : get_algorithm_time(ekp)[i])) + end + 1 + ekp_samples[α] = (get_u(ekp, closest_iter), closest_iter == n_iters + 1 ? hcat([forward_map(u, model) for u in eachcol(get_u(ekp, closest_iter))]...) : get_g(ekp, closest_iter)) + end + + # [2] MCMC run + mcmc_samples = Dict() + for α in αs + @info "Running MCMC for α = $α" + + prior_cov, prior_inv, obs_inv = cov(prior), inv(cov(prior)), inv(obs_noise_cov) + mcmc_samples[α] = (zeros(input_dim, 0), zeros(output_dim, 0)) + do_mcmc( + input_dim, + x -> begin + g = α == 0 ? 0y : forward_map(x, model) + -2 \ x' * prior_inv * x - 2 \ (y - g)' * obs_inv * (y - g) * α + end, + step1_mcmc_num_chains, + step1_mcmc_samples_per_chain, + step1_mcmc_sampler, + prior_cov, + true_parameter; + subsample_rate = step1_mcmc_subsample_rate, + ) do samp, _ + gsamp = hcat([forward_map(s, model) for s in eachcol(samp)]...) # TODO: This is wasteful, as they have actually already been computed + mcmc_samples[α] = (hcat(mcmc_samples[α][1], samp), hcat(mcmc_samples[α][2], gsamp)) + end + end + @info "MCMC finished" + + # [3] Save everything to a file + #! format: off + save( + "datafiles/ekp_$(problem)_$(trial).jld2", + "ekp", ekp, + "ekp_samples", ekp_samples, + "mcmc_samples", mcmc_samples, + "prior", prior, + "y", y, + "obs_noise_cov", obs_noise_cov, + "model", model, + "true_parameter", true_parameter, + ) + #! 
format: on +end diff --git a/examples/DimensionReduction/step2_build_and_compare_diagnostic_matrices.jl b/examples/DimensionReduction/step2_build_and_compare_diagnostic_matrices.jl new file mode 100644 index 000000000..1e5afda71 --- /dev/null +++ b/examples/DimensionReduction/step2_build_and_compare_diagnostic_matrices.jl @@ -0,0 +1,261 @@ +using LinearAlgebra +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +using Statistics +using Distributions +using Plots +using JLD2 +using Manopt, Manifolds + +include("./settings.jl") +include("./problems/problem_linear.jl") +include("./problems/problem_linear_exp.jl") +include("./problems/problem_lorenz.jl") +include("./problems/problem_linlinexp.jl") + +if !isfile("datafiles/ekp_$(problem)_1.jld2") + include("step1_generate_inverse_problem_data.jl") +end + +#Utilities +function cossim(x::VV1, y::VV2) where {VV1 <: AbstractVector, VV2 <: AbstractVector} + return dot(x, y) / (norm(x) * norm(y)) +end +function cossim_pos(x::VV1, y::VV2) where {VV1 <: AbstractVector, VV2 <: AbstractVector} + return abs(cossim(x, y)) +end +function cossim_cols(X::AM1, Y::AM2) where {AM1 <: AbstractMatrix, AM2 <: AbstractMatrix} + return [cossim_pos(c1, c2) for (c1, c2) in zip(eachcol(X), eachcol(Y))] +end + +all_diagnostic_matrices_u = Dict() +all_diagnostic_matrices_g = Dict() + +for trial in 1:num_trials + @info "Trial $trial" + diagnostic_matrices_u = Dict() + diagnostic_matrices_g = Dict() + + # Load the EKP iterations + loaded = load("datafiles/ekp_$(problem)_$(trial).jld2") + ekp = loaded["ekp"] + ekp_samp = loaded["ekp_samples"] + mcmc_samp = loaded["mcmc_samples"] + prior = loaded["prior"] + obs_noise_cov = loaded["obs_noise_cov"] + y = loaded["y"] + model = loaded["model"] + + prior_cov = cov(prior) + prior_invrt = sqrt(inv(prior_cov)) + prior_rt = sqrt(prior_cov) + obs_invrt = sqrt(inv(obs_noise_cov)) + obs_inv = inv(obs_noise_cov) + + ekp_samp_grad = Dict() + mcmc_samp_grad = Dict() + ekp_samp_Vgrad = Dict() + mcmc_samp_Vgrad = Dict() + for (dict_samp, dict_samp_grad, dict_samp_Vgrad) in ( + (ekp_samp, ekp_samp_grad, ekp_samp_Vgrad), + (mcmc_samp, mcmc_samp_grad, mcmc_samp_Vgrad), + ) + for (α, (samps, gsamps)) in dict_samp + dict_samp_grad[α] = [] + for grad_type in grad_types + @info "Computing gradients: α=$α, grad_type=$grad_type" + grads = if grad_type == :perfect + jac_forward_map(samps, model) + elseif grad_type == :mean + grad = jac_forward_map(reshape(mean(samps; dims = 2), :, 1), model)[1] + fill(grad, size(samps, 2)) + elseif grad_type == :linreg + grad = (gsamps .- mean(gsamps; dims = 2)) / (samps .- mean(samps; dims = 2)) + fill(grad, size(samps, 2)) + elseif grad_type == :localsl + map(zip(eachcol(samps), eachcol(gsamps))) do (u, g) + weights = exp.(-1/2 * norm.(eachcol(u .- samps)).^2) # TODO: Matrix weighting + D = Diagonal(sqrt.(weights)) + uw = (samps .- mean(samps * Diagonal(weights); dims = 2)) * D + gw = (gsamps .- mean(gsamps * Diagonal(weights); dims = 2)) * D + gw / uw + end + else + throw("Unknown grad_type=$grad_type") + end + + push!(dict_samp_grad[α], (grad_type, grads)) + end + + dict_samp_Vgrad[α] = [] + for Vgrad_type in Vgrad_types + @info "Computing Vgradients: α=$α, Vgrad_type=$Vgrad_type" + grads = if Vgrad_type == :egi + ∇Vs = map(enumerate(zip(eachcol(samps), eachcol(gsamps)))) do (i, (u, g)) + @info "In full EGI procedure; particle $i/$(size(samps, 2))" + + yys = eachcol(rand(MvNormal((1-α)g + α*y, (1-α)obs_noise_cov), α == 1.0 ? 
1 : step2_Vgrad_num_samples)) # If α == 1.0, all samples will be the same anyway + map(yys) do yy + Vs = [1/2 * (yy - g)' * obs_inv * (yy - g) for g in eachcol(gsamps)] + + X = samps[:, [1:i-1; i+1:end]] .- u + Z = X ./ norm.(eachcol(X))' + A = hcat(X'Z, (X'Z).^2 / 2) + + ξ, γ = step2_egi_ξ, step2_egi_γ + Γ = γ * (factorial(3) \ Diagonal(norm.(eachcol(X)).^3) + ξ * I) # The paper has γ², but that's wrong + + Y = Vs[[1:i-1; i+1:end]] .- Vs[i] + + ū = pinv(Γ \ A) * (Γ \ Y) + Z * ū[1:end÷2] + end + end + + mean( + mean( + ∇V * ∇V' + for ∇V in ∇Vs_at_x + ) for ∇Vs_at_x in ∇Vs + ) + else + throw("Unknown Vgrad_type=$Vgrad_type") + end + + push!(dict_samp_Vgrad[α], (Vgrad_type, grads)) + end + end + end + + # random samples + @assert 0.0 in αs + prior_samp, prior_gsamp = ekp_samp[0.0] + num_prior_samps = size(prior_samp, 2) + + @info "Construct PCA matrices" + pca_u = prior_samp' + pca_g = prior_gsamp' + + diagnostic_matrices_u["pca_u"] = pca_u, :gray + diagnostic_matrices_g["pca_g"] = pca_g, :gray + + for α in αs + for (sampler, dict_samp, dict_samp_grad, dict_samp_Vgrad) in ( + ("ekp", ekp_samp, ekp_samp_grad, ekp_samp_Vgrad), + ("mcmc", mcmc_samp, mcmc_samp_grad, mcmc_samp_Vgrad), + ) + samp, gsamp = dict_samp[α] + for (Vgrad_type, grads) in dict_samp_Vgrad[α] + name_suffix = "$(α)_$(sampler)_$(Vgrad_type)" + @info "Construct $name_suffix matrices" + + Hu = prior_rt * mean(grads) * prior_rt + # TODO: Hg + + diagnostic_matrices_u["Hu_$name_suffix"] = Hu, :black + end + + for (grad_type, grads) in dict_samp_grad[α] + name_suffix = "$(α)_$(sampler)_$(grad_type)" + @info "Construct $name_suffix matrices" + + Hu = prior_rt * mean( + grad' * obs_inv * ( + (1-α)obs_noise_cov + α^2 * (y - g) * (y - g)' + ) * obs_inv * grad + for (g, grad) in zip(eachcol(gsamp), grads) + ) * prior_rt + + Hg = if α == 0 + obs_invrt * mean(grad * prior_cov * grad' for grad in grads) * obs_invrt + else + Vs0 = qr(randn(output_dim, output_dim)).Q[:, 1:step2_manopt_num_dims] + + f = (_, Vs) -> begin + res = mean( + begin + mat = obs_invrt * grad * prior_rt - Vs*(Vs'*obs_invrt * grad * prior_rt) + a = obs_invrt * (y - g) + + (1-α)norm(mat) + α^2 * norm(a' * mat) + end + for (g, grad) in zip(eachcol(gsamp), grads) + ) + println(res) + res + end + + egrad = Vs -> begin + -2mean( + begin + a = obs_invrt * (y - g) + mat = obs_invrt * grad * prior_cov * grad' * obs_invrt * (I - Vs * Vs') * ((1-α)I + α^2 * a * a') + + mat + mat' + end + for (g, grad) in zip(eachcol(gsamp), grads) + ) * Vs + end + rgrad = (_, Vs) -> begin + egrd = egrad(Vs) + res = egrd - Vs * (Vs' * egrd) + res + end + Vs = quasi_Newton(Grassmann(output_dim, step2_manopt_num_dims), f, rgrad, Vs0; stopping_criterion = StopWhenGradientNormLess(3.0)) + Vs = hcat(Vs, randn(output_dim, output_dim - step2_manopt_num_dims)) + Vs * diagm(vcat(step2_manopt_num_dims:-1:1, zeros(output_dim - step2_manopt_num_dims))) * Vs' + end + + diagnostic_matrices_u["Hu_$name_suffix"] = Hu, :black + diagnostic_matrices_g["Hg_$name_suffix"] = Hg, :black + end + end + end + + for (name, (value, color)) in diagnostic_matrices_u + if !haskey(all_diagnostic_matrices_u, name) + all_diagnostic_matrices_u[name] = ([], color) + end + push!(all_diagnostic_matrices_u[name][1], value) + end + for (name, (value, color)) in diagnostic_matrices_g + if !haskey(all_diagnostic_matrices_g, name) + all_diagnostic_matrices_g[name] = ([], color) + end + push!(all_diagnostic_matrices_g[name][1], value) + end + + save( + "datafiles/diagnostic_matrices_$(problem)_$(trial).jld2", + vcat([[name, value] for (name, 
(value, _)) in diagnostic_matrices_u]...)..., + vcat([[name, value] for (name, (value, _)) in diagnostic_matrices_g]...)..., + ) +end + +using Plots.Measures +gr(; size = (1.6 * 1200, 600), legend = true, bottom_margin = 10mm, left_margin = 10mm) +default(; titlefont = 20, legendfontsize = 12, guidefont = 14, tickfont = 14) + +trunc = 15 +trunc = min(trunc, input_dim, output_dim) +# color names in https://github.com/JuliaGraphics/Colors.jl/blob/master/src/names_data.jl + +alg = LinearAlgebra.QRIteration() +plots = map([:in, :out]) do in_or_out + diagnostics = in_or_out == :in ? all_diagnostic_matrices_u : all_diagnostic_matrices_g + ref = in_or_out == :in ? "Hu_0.0_ekp_perfect" : "Hg_0.0_ekp_perfect" + + p = plot(; title = "Similarity of spectrum of $(in_or_out)put diagnostic", xlabel = "SV index") + for (name, (mats, color)) in diagnostics + svds = [svd(mat; alg) for mat in mats] + sims = [cossim_cols(s.V, svd(ref_diag; alg).V) for (s, ref_diag) in zip(svds, diagnostics[ref][1])] + name == ref || + plot!(p, mean(sims)[1:trunc]; ribbon = std(sims)[1:trunc], label = "sim ($ref vs. $name)", color) + mean_S = mean([s.S[1:trunc] for s in svds]) + plot!(p, mean_S ./ mean_S[1]; label = "SVs ($name)", linestyle = :dash, linewidth = 3, color) + end + + p +end +plot(plots...; layout = @layout([a b])) +savefig("figures/spectrum_comparison_$problem.png") diff --git a/examples/DimensionReduction/step3_estimate_posteriors.jl b/examples/DimensionReduction/step3_estimate_posteriors.jl new file mode 100644 index 000000000..b4bc128b0 --- /dev/null +++ b/examples/DimensionReduction/step3_estimate_posteriors.jl @@ -0,0 +1,259 @@ +using AdvancedMH +using Distributions +using ForwardDiff +using JLD2 +using LinearAlgebra +using MCMCChains +using Plots +using Random +using Statistics + +include("./settings.jl") +include("./util.jl") +rng = Random.MersenneTwister(rng_seed) + +if !isfile("datafiles/ekp_$(problem)_1.jld2") + include("step1_generate_inverse_problem_data.jl") +end +if !isfile("datafiles/diagnostic_matrices_$(problem)_1.jld2") + include("step2_build_and_compare_diagnostic_matrices.jl") +end + +means_full = Dict() + +for (in_diag, in_r, out_diag, out_r) in step3_diagnostics_to_use + @info "Diagnostic matrices = ($in_diag [1-$in_r], $out_diag [1-$out_r])" + + rel_error_full_rmse = 0 + rel_error_red_rmse = 0 + + for trial in 1:num_trials + # Load the EKP iterations + loaded = load("datafiles/ekp_$(problem)_$(trial).jld2") + ekp = loaded["ekp"] + prior = loaded["prior"] + obs_noise_cov = loaded["obs_noise_cov"] + y = loaded["y"] + model = loaded["model"] + true_parameter = loaded["true_parameter"] + + prior_cov = cov(prior) + prior_inv = inv(prior_cov) + prior_invrt = sqrt(inv(prior_cov)) + prior_rt = sqrt(prior_cov) + obs_invrt = sqrt(inv(obs_noise_cov)) + obs_inv = inv(obs_noise_cov) + + # Load diagnostic container + diagnostic_mats = load("datafiles/diagnostic_matrices_$(problem)_$(trial).jld2") + + Hu = diagnostic_mats[in_diag] + Hg = diagnostic_mats[out_diag] + svdu = svd(Hu; alg = LinearAlgebra.QRIteration()) + svdg = svd(Hg; alg = LinearAlgebra.QRIteration()) + U_r = svdu.V[:, 1:in_r] + V_r = svdg.V[:, 1:out_r] + + # Projection matrices + P = U_r' * prior_invrt + Pinv = prior_rt * U_r + Q = V_r' * obs_invrt + + obs_noise_cov_r = V_r' * V_r # Vr' * invrt(noise) * noise * invrt(noise) * Vr + @assert obs_noise_cov_r ≈ I + prior_cov_r = U_r' * U_r + prior_cov_r_inv = inv(prior_cov_r) + y_r = Q * y + prior_r = ParameterDistribution( + Dict( + "distribution" => Parameterized(MvNormal(zeros(in_r), 
prior_cov_r)), + "constraint" => repeat([no_constraint()], in_r), + "name" => "param_$(in_r)", + ), + ) + + # TODO: Fix assert below for the actual type of `prior` + # @assert prior isa MvNormal && mean(prior) == zeros(input_dim) + + # Let prior = N(0, C) and let x ~ prior + # Then the distribution of x | P*x=x_r is N(Mmean * x_r, Mcov) + C = prior_cov + Mmean = C * P' * inv(P * C * P') + @assert Pinv ≈ Mmean + Mcov = C - Mmean * P * C + 1e-13 * I + Mcov = (Mcov + Mcov') / 2 # Otherwise, it's not numerically Hermitian + covsamps = rand(MvNormal(zeros(input_dim), Mcov), step3_num_marginalization_samples) + + if step3_posterior_sampler == :mcmc + mean_full = if trial in keys(means_full) + means_full[trial] + else + mean_full = zeros(input_dim) + do_mcmc( + input_dim, + x -> begin + g = forward_map(x, model) + -2 \ x' * prior_inv * x - 2 \ (y - g)' * obs_inv * (y - g) + end, + step3_mcmc_num_chains, + step3_mcmc_samples_per_chain, + step3_mcmc_sampler, + prior_cov, + true_parameter, + ) do samp, num_batches + mean_full += mean(samp; dims = 2) / num_batches + end + means_full[trial] = mean_full + mean_full + end + mean_full_red = P * mean_full + + if step3_run_reduced_in_full_space + mean_red_full = zeros(input_dim) + do_mcmc( + input_dim, + xfull -> begin + xred = P * xfull + samp = covsamps .+ Mmean * xred + gsamp = map(x -> forward_map(x, model), eachcol(samp)) + + return -2 \ xfull' * prior_inv * xfull + if step3_marginalization == :loglikelihood + mean(-2 \ (y_r - Q * g)' * (y_r - Q * g) for (x, g) in zip(eachcol(samp), gsamp)) + elseif step3_marginalization == :forward_model + g = mean(gsamp) + -2 \ (y_r - Q * g)' * (y_r - Q * g) + else + throw("Unknown step3_marginalization=$step3_marginalization") + end + end, + step3_mcmc_num_chains, + step3_mcmc_samples_per_chain, + step3_mcmc_sampler, + prior_cov, + true_parameter, + ) do samp, num_batches + mean_red_full += mean(samp; dims = 2) / num_batches + end + mean_red = P * mean_red_full + else + mean_red = zeros(in_r) + do_mcmc( + in_r, + xred -> begin + samp = covsamps .+ Mmean * xred + gsamp = map(x -> forward_map(x, model), eachcol(samp)) + + return -2 \ xred' * prior_cov_r_inv * xred + if step3_marginalization == :loglikelihood + mean(-2 \ (y_r - Q * g)' * (y_r - Q * g) for (x, g) in zip(eachcol(samp), gsamp)) + elseif step3_marginalization == :forward_model + g = mean(gsamp) + -2 \ (y_r - Q * g)' * (y_r - Q * g) + else + throw("Unknown step3_marginalization=$step3_marginalization") + end + end, + step3_mcmc_num_chains, + step3_mcmc_samples_per_chain, + step3_mcmc_sampler, + prior_cov_r, + P * true_parameter, + ) do samp, num_batches + mean_red += mean(samp; dims = 2) / num_batches + end + mean_red_full = Pinv * mean_red # This only works since it's the mean (linear) — if not, we'd have to use the covsamps here (same in a few other places) + end + elseif step3_posterior_sampler == :eks + step3_marginalization == :forward_model || throw( + "EKS sampling from the reduced posterior is only supported when marginalizing over the forward model.", + ) + + mean_full = if trial in keys(means_full) + means_full[trial] + else + u, _ = do_eks( + input_dim, + x -> forward_map(x, model), + y, + obs_noise_cov, + prior, + rng, + step3_eks_ensemble_size, + step3_eks_max_iters, + ) + mean_full = mean(u; dims = 2) + means_full[trial] = mean_full + mean_full + end + mean_full_red = P * mean_full + + if step3_run_reduced_in_full_space + u, _ = do_eks( + input_dim, + xfull -> begin + xred = P * xfull + samp = covsamps .+ Mmean * xred + gsamp = map(x -> 
forward_map(x, model), eachcol(samp)) + return Q * mean(gsamp) + end, + y_r, + 1.0 * I(out_r), + prior, + rng, + step3_eks_ensemble_size, + step3_eks_max_iters, + ) + mean_red_full = mean(u; dims = 2) + mean_red = P * mean_red_full + else + u, _ = do_eks( + in_r, + xred -> begin + samp = covsamps .+ Mmean * xred + gsamp = map(x -> forward_map(x, model), eachcol(samp)) + return Q * mean(gsamp) + end, + y_r, + 1.0 * I(out_r), + prior_r, + rng, + step3_eks_ensemble_size, + step3_eks_max_iters, + ) + mean_red = mean(u; dims = 2) + mean_red_full = Pinv * mean_red + end + else + throw("Unknown step3_posterior_sampler=$step3_posterior_sampler") + end + + rel_error_full = norm(mean_full - mean_red_full) / norm(mean_full) + rel_error_red = norm(mean_full_red - mean_red) / norm(mean_full_red) + + @info """ + True: $(true_parameter[1:5]) + Mean (in full space): $(mean_full[1:5]) + Red. mean (in full space): $(mean_red_full[1:5]) + + Mean (in red. space): $mean_full_red + Red. mean (in red. space): $mean_red + + Relative error on mean in full space: $rel_error_full + Relative error on mean in reduced space: $rel_error_red + """ + + rel_error_full_rmse += rel_error_full^2 + rel_error_red_rmse += rel_error_red^2 + end + + rel_error_full_rmse = sqrt(rel_error_full_rmse / num_trials) + rel_error_red_rmse = sqrt(rel_error_red_rmse / num_trials) + + open("datafiles/output_error_$(problem).log", "a") do f + println(f, "$in_diag, $in_r, $out_diag, $out_r, $rel_error_full_rmse, $rel_error_red_rmse") + end + + # [A] The relative error seems larger in the reduced space + # The reason is likely the whitening that happens. Small absolute errors in the full space + # can be amplified in the reduced space due to the different scales in the prior. I think + # the full space is probably the one we should be concerned about. +end diff --git a/examples/DimensionReduction/util.jl b/examples/DimensionReduction/util.jl new file mode 100644 index 000000000..912d9faee --- /dev/null +++ b/examples/DimensionReduction/util.jl @@ -0,0 +1,62 @@ +using AdvancedMH +using Distributions +using ForwardDiff +using MCMCChains + +function do_mcmc( + callback, + dim, + logpost, + num_chains, + num_samples_per_chain, + mcmc_sampler, + prior_cov, + initial_guess; + subsample_rate = 1, +) + density_model = DensityModel(logpost) + sampler = if mcmc_sampler == :mala + MALA(x -> MvNormal(0.0001 * prior_cov * x, 0.0001 * 2 * prior_cov)) + elseif mcmc_sampler == :rw + RWMH(MvNormal(zeros(dim), 0.01prior_cov)) + else + throw("Unknown mcmc_sampler=$mcmc_sampler") + end + + num_batches = (num_chains + 7) ÷ 8 + for batch in 1:num_batches + num_chains_in_batch = min(8, num_chains - (batch - 1) * 8) + chain = sample( + density_model, + sampler, + MCMCThreads(), + num_samples_per_chain, + num_chains_in_batch; + chain_type = Chains, + initial_params = [initial_guess for _ in 1:num_chains_in_batch], + ) + samp = vcat( + [vec(MCMCChains.get(chain, Symbol("param_$i"))[1]'[:, (end ÷ 2):subsample_rate:end])' for i in 1:dim]..., + ) + callback(samp, num_batches) + end +end + +function do_eks(dim, G, y, obs_noise_cov, prior, rng, num_ensemble, num_iters_max) + initial_ensemble = construct_initial_ensemble(rng, prior, num_ensemble) + ekp = EnsembleKalmanProcess( + initial_ensemble, + y, + obs_noise_cov, + Sampler(prior); + rng, + scheduler = EKSStableScheduler(2.0, 0.01), + ) + + for i in 1:num_iters_max + g = hcat([G(params) for params in eachcol(get_ϕ_final(prior, ekp))]...) 
+ isnothing(update_ensemble!(ekp, g)) || break + end + + return get_u_final(ekp), get_g_final(ekp) +end
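For reference, a minimal driver sketch of how the three step scripts added above are intended to be chained. This is only an illustration and assumes the working directory is examples/DimensionReduction, that the example environment from Project.toml is instantiated, and that settings.jl has been edited to the desired configuration (each step script re-includes settings.jl, and steps 2 and 3 include the earlier steps automatically when their datafiles are missing):

using Pkg
Pkg.activate(@__DIR__)   # use the example's Project.toml
Pkg.instantiate()

include("step1_generate_inverse_problem_data.jl")          # writes datafiles/ekp_<problem>_<trial>.jld2
include("step2_build_and_compare_diagnostic_matrices.jl")  # writes diagnostic matrices and figures/spectrum_comparison_<problem>.png
include("step3_estimate_posteriors.jl")                    # appends results to datafiles/output_error_<problem>.log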
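The closed-form Jacobian in problem_linear_exp.jl (dG_i/dx_j = G_ij * exp(x_j)) can be spot-checked against ForwardDiff, which is already listed in the example's Project.toml. A small sketch, run with the example environment active; the dimensions 20 and 10 and the seed 0 are arbitrary choices for illustration only:

using ForwardDiff, LinearAlgebra, Random
include("problems/problem_linear_exp.jl")  # defines linear_exp, forward_map, jac_forward_map

rng = Random.MersenneTwister(0)
prior, y, obs_noise_cov, model, true_parameter = linear_exp(20, 10, rng)

J_analytic = jac_forward_map(reshape(true_parameter, :, 1), model)[1]         # closed-form G .* exp.(x)'
J_ad = ForwardDiff.jacobian(x -> forward_map(x, model), vec(true_parameter))  # autodiff reference
@assert isapprox(J_analytic, J_ad; rtol = 1e-10)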