From 42a31acaeac41163e52f3daa02dce02d331152e4 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sat, 22 May 2021 11:48:03 -0700 Subject: [PATCH 1/7] added ot_reg --- src/OptimalTransport.jl | 137 ++-------------------------------------- src/ot_reg.jl | 132 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 +- 3 files changed, 139 insertions(+), 132 deletions(-) create mode 100644 src/ot_reg.jl diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1414a92d..b075a1b2 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2 export emd, emd2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 -export quadreg +export ot_reg_plan const MOI = MathOptInterface @@ -506,137 +506,12 @@ function sinkhorn_barycenter( return u_all[1, :] .* (K_all[1] * v_all[1, :]) end -""" - quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) - -Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`, -using the semismooth Newton algorithm [Lorenz 2016]. - -This implementation makes use of IterativeSolvers.jl and SparseArrays.jl. - -Parameters:\n -θ: starting Armijo parameter.\n -tol: tolerance of marginal error.\n -maxiter: maximum interation number.\n -κ: control parameter of Armijo.\n -δ: small constant for the numerical stability of conjugate gradient iterative solver.\n - -Tips: -If the algorithm does not converge, try some different values of θ. - -Reference: -Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4. -""" -function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5) - if !(sum(mu) ≈ sum(nu)) - throw(ArgumentError("Error: mu and nu must lie in the simplex")) - end - - N = length(mu) - M = length(nu) - - # initialize dual potentials as uniforms - a = ones(M) ./ M - b = ones(N) ./ N - γ = spzeros(M, N) - - da = spzeros(M) - db = spzeros(N) - - converged = false - - function DualObjective(a, b) - A = a .* ones(N)' + ones(M) .* b' - C' - - return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b)) - end - - # computes minimizing directions, update γ - function search_dir!(a, b, da, db) - P = a * ones(N)' .+ ones(M) * b' .- C' - - σ = 1.0 * sparse(P .>= 0) - γ = sparse(max.(P, 0) ./ ϵ) - - G = vcat( - hcat(spdiagm(0 => σ * ones(N)), σ), - hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), - ) - - h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) - - x = cg(G + δ * I, -ϵ * h) - - da = x[1:M] - return db = x[(M + 1):end] - end - - function search_dir(a, b) - P = a * ones(N)' .+ ones(M) * b' .- C' - - σ = 1.0 * sparse(P .>= 0) - γ = sparse(max.(P, 0) ./ ϵ) - - G = vcat( - hcat(spdiagm(0 => σ * ones(N)), σ), - hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), - ) - - h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) - - x = cg(G + δ * I, -ϵ * h) - - return x[1:M], x[(M + 1):end] - end - - # computes optimal maginitude in the minimizing directions - function search_t(a, b, da, db, θ) - d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu)) - - ϕ₀ = DualObjective(a, b) - t = 1 - - while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d - t *= κ - - if t < 1e-15 - # @warn "@ i = $i, t = $t , armijo did not converge" - break - end - end - return t - end - - for i in 1:maxiter - - # search_dir!(a, b, da, db) - da, db = search_dir(a, b) - - t = search_t(a, b, da, db, θ) - - a += t * da - b += t * db - - err1 = norm(γ * ones(N) - nu, Inf) - err2 = norm(sparse(γ') * ones(M) - mu, Inf) - - if err1 <= tol && err2 <= tol - converged = true - @warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n" - break - elseif i == maxiter - @warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n" - end - - @debug " t = $t" - @debug "marginal @ i = $i: err1 = $err1, err2 = $err2 " - end - - if !converged - @warn "SemiSmooth Newton algorithm did not converge" +function ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...) + if (reg_func == "L2") && (method == "lorenz") + return quadreg(mu, nu, C, eps; kwargs...) + else + @warn "Unimplemented" end - - return sparse(γ') end end diff --git a/src/ot_reg.jl b/src/ot_reg.jl new file mode 100644 index 00000000..846d0a0a --- /dev/null +++ b/src/ot_reg.jl @@ -0,0 +1,132 @@ +""" + quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) + +Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`, +using the semismooth Newton algorithm [Lorenz 2016]. + +This implementation makes use of IterativeSolvers.jl and SparseArrays.jl. + +Parameters:\n +θ: starting Armijo parameter.\n +tol: tolerance of marginal error.\n +maxiter: maximum interation number.\n +κ: control parameter of Armijo.\n +δ: small constant for the numerical stability of conjugate gradient iterative solver.\n + +Tips: +If the algorithm does not converge, try some different values of θ. + +Reference: +Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4. +""" +function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5) + if !(sum(mu) ≈ sum(nu)) + throw(ArgumentError("Error: mu and nu must lie in the simplex")) + end + + N = length(mu) + M = length(nu) + + # initialize dual potentials as uniforms + a = ones(M) ./ M + b = ones(N) ./ N + γ = spzeros(M, N) + + da = spzeros(M) + db = spzeros(N) + + converged = false + + function DualObjective(a, b) + A = a .* ones(N)' + ones(M) .* b' - C' + + return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b)) + end + + # computes minimizing directions, update γ + function search_dir!(a, b, da, db) + P = a * ones(N)' .+ ones(M) * b' .- C' + + σ = 1.0 * sparse(P .>= 0) + γ = sparse(max.(P, 0) ./ ϵ) + + G = vcat( + hcat(spdiagm(0 => σ * ones(N)), σ), + hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), + ) + + h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) + + x = cg(G + δ * I, -ϵ * h) + + da = x[1:M] + return db = x[(M + 1):end] + end + + function search_dir(a, b) + P = a * ones(N)' .+ ones(M) * b' .- C' + + σ = 1.0 * sparse(P .>= 0) + γ = sparse(max.(P, 0) ./ ϵ) + + G = vcat( + hcat(spdiagm(0 => σ * ones(N)), σ), + hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), + ) + + h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) + + x = cg(G + δ * I, -ϵ * h) + + return x[1:M], x[(M + 1):end] + end + + # computes optimal maginitude in the minimizing directions + function search_t(a, b, da, db, θ) + d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu)) + + ϕ₀ = DualObjective(a, b) + t = 1 + + while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d + t *= κ + + if t < 1e-15 + # @warn "@ i = $i, t = $t , armijo did not converge" + break + end + end + return t + end + + for i in 1:maxiter + + # search_dir!(a, b, da, db) + da, db = search_dir(a, b) + + t = search_t(a, b, da, db, θ) + + a += t * da + b += t * db + + err1 = norm(γ * ones(N) - nu, Inf) + err2 = norm(sparse(γ') * ones(M) - mu, Inf) + + if err1 <= tol && err2 <= tol + converged = true + @warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n" + break + elseif i == maxiter + @warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n" + end + + @debug " t = $t" + @debug "marginal @ i = $i: err1 = $err1, err2 = $err2 " + end + + if !converged + @warn "SemiSmooth Newton algorithm did not converge" + end + + return sparse(γ') +end diff --git a/test/runtests.jl b/test/runtests.jl index c62b72f7..ba78651a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -165,7 +165,7 @@ end # compute optimal transport map (Julia implementation + POT) eps = 0.25 - γ = quadreg(μ, ν, C, eps) + γ = ot_reg_plan(μ, ν, C, eps) γ_pot = POT.Smooth.smooth_ot_dual(μ, ν, C, eps; stopThr=1e-9) # need to use a larger tolerance here because of a quirk with the POT solver @test norm(γ - γ_pot, Inf) < 1e-4 From f809e197993089ec56e6861f58806f9d6a97e053 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 14:51:26 -0700 Subject: [PATCH 2/7] updated tests and docstrings --- src/OptimalTransport.jl | 22 +++++++++++++++++++++- test/runtests.jl | 5 +++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index b075a1b2..32c0003e 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -506,7 +506,27 @@ function sinkhorn_barycenter( return u_all[1, :] .* (K_all[1] * v_all[1, :]) end -function ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...) +""" + ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...) + +Compute the optimal transport plan between `mu` and `nu` for optimal transport with a +general choice of regulariser `math Ω(γ)`. Solves for `gamma` that minimises + +```math +\\inf_{γ ∈ Π(μ, ν)} \\langle γ, C \\rangle + ε Ω(γ) +``` + +Supported choices of `math Ω` are: +- L2: `math Ω(γ) = \\frac{1}{2} \\| γ \\|_2^2`, `reg_func = "L2"` + +Supported solution methods are: +- L2: `method = "lorenz"` for the semi-smooth Newton method of Lorenz et al. + +References + +Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. Applied Mathematics & Optimization, pp.1-31. +""" +function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) if (reg_func == "L2") && (method == "lorenz") return quadreg(mu, nu, C, eps; kwargs...) else diff --git a/test/runtests.jl b/test/runtests.jl index ba78651a..ac0342ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -165,8 +165,13 @@ end # compute optimal transport map (Julia implementation + POT) eps = 0.25 +<<<<<<< HEAD γ = ot_reg_plan(μ, ν, C, eps) γ_pot = POT.Smooth.smooth_ot_dual(μ, ν, C, eps; stopThr=1e-9) +======= + γ = ot_reg_plan(μ, ν, C, eps; reg_func="L2", method="lorenz") + γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps; max_iter=5000)) +>>>>>>> d6c9ee3 (updated tests and docstrings) # need to use a larger tolerance here because of a quirk with the POT solver @test norm(γ - γ_pot, Inf) < 1e-4 end From 4e758f8ecf8847285d0324c8e5f08896e14b1df7 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sat, 22 May 2021 19:52:09 -0700 Subject: [PATCH 3/7] fix docstring --- docs/src/index.md | 4 ++-- examples/basic/script.jl | 4 ++-- src/OptimalTransport.jl | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index fd2c4a51..6ad38504 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -23,7 +23,7 @@ sinkhorn_unbalanced sinkhorn_unbalanced2 ``` -## Quadratically regularised optimal transport +## Optimal transport with general regularisation ```@docs -quadreg +ot_reg_plan ``` diff --git a/examples/basic/script.jl b/examples/basic/script.jl index d484f739..625c461b 100644 --- a/examples/basic/script.jl +++ b/examples/basic/script.jl @@ -94,7 +94,7 @@ sinkhorn2(μ, ν, C, ε) # resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as # a sparse matrix. -quadreg(μ, ν, C, ε; maxiter=500); +ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500); # ## Stabilized Sinkhorn algorithm # @@ -190,7 +190,7 @@ heatmap( # Notice how the "edges" of the transport plan are sharper if we use quadratic regularisation # instead of entropic regularisation: -γquad = Matrix(quadreg(μ, ν, C, 5; maxiter=500)) +γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500)) heatmap( μsupport, νsupport, diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 32c0003e..f6ae6ab2 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -517,7 +517,7 @@ general choice of regulariser `math Ω(γ)`. Solves for `gamma` that minimises ``` Supported choices of `math Ω` are: -- L2: `math Ω(γ) = \\frac{1}{2} \\| γ \\|_2^2`, `reg_func = "L2"` +- L2: ``Ω(γ) = \\frac{1}{2} \\| γ \\|_2^2``, `reg_func = "L2"` Supported solution methods are: - L2: `method = "lorenz"` for the semi-smooth Newton method of Lorenz et al. From 3e8a3c8a0b57d0ef3e64e1963b7bdea632bb74a3 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 12:45:10 -0700 Subject: [PATCH 4/7] remove additional file --- src/OptimalTransport.jl | 134 ++++++++++++++++++++++++++++++++++++++++ src/ot_reg.jl | 132 --------------------------------------- 2 files changed, 134 insertions(+), 132 deletions(-) delete mode 100644 src/ot_reg.jl diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index f6ae6ab2..146ddeed 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -534,4 +534,138 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) end end + +""" + quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) + +Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`, +using the semismooth Newton algorithm [Lorenz 2016]. + +This implementation makes use of IterativeSolvers.jl and SparseArrays.jl. + +Parameters:\n +θ: starting Armijo parameter.\n +tol: tolerance of marginal error.\n +maxiter: maximum interation number.\n +κ: control parameter of Armijo.\n +δ: small constant for the numerical stability of conjugate gradient iterative solver.\n + +Tips: +If the algorithm does not converge, try some different values of θ. + +Reference: +Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4. +""" +function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5) + if !(sum(mu) ≈ sum(nu)) + throw(ArgumentError("Error: mu and nu must lie in the simplex")) + end + + N = length(mu) + M = length(nu) + + # initialize dual potentials as uniforms + a = ones(M) ./ M + b = ones(N) ./ N + γ = spzeros(M, N) + + da = spzeros(M) + db = spzeros(N) + + converged = false + + function DualObjective(a, b) + A = a .* ones(N)' + ones(M) .* b' - C' + + return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b)) + end + + # computes minimizing directions, update γ + function search_dir!(a, b, da, db) + P = a * ones(N)' .+ ones(M) * b' .- C' + + σ = 1.0 * sparse(P .>= 0) + γ = sparse(max.(P, 0) ./ ϵ) + + G = vcat( + hcat(spdiagm(0 => σ * ones(N)), σ), + hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), + ) + + h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) + + x = cg(G + δ * I, -ϵ * h) + + da = x[1:M] + return db = x[(M + 1):end] + end + + function search_dir(a, b) + P = a * ones(N)' .+ ones(M) * b' .- C' + + σ = 1.0 * sparse(P .>= 0) + γ = sparse(max.(P, 0) ./ ϵ) + + G = vcat( + hcat(spdiagm(0 => σ * ones(N)), σ), + hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), + ) + + h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) + + x = cg(G + δ * I, -ϵ * h) + + return x[1:M], x[(M + 1):end] + end + + # computes optimal maginitude in the minimizing directions + function search_t(a, b, da, db, θ) + d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu)) + + ϕ₀ = DualObjective(a, b) + t = 1 + + while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d + t *= κ + + if t < 1e-15 + # @warn "@ i = $i, t = $t , armijo did not converge" + break + end + end + return t + end + + for i in 1:maxiter + + # search_dir!(a, b, da, db) + da, db = search_dir(a, b) + + t = search_t(a, b, da, db, θ) + + a += t * da + b += t * db + + err1 = norm(γ * ones(N) - nu, Inf) + err2 = norm(sparse(γ') * ones(M) - mu, Inf) + + if err1 <= tol && err2 <= tol + converged = true + @warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n" + break + elseif i == maxiter + @warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n" + end + + @debug " t = $t" + @debug "marginal @ i = $i: err1 = $err1, err2 = $err2 " + end + + if !converged + @warn "SemiSmooth Newton algorithm did not converge" + end + + return sparse(γ') +end + end diff --git a/src/ot_reg.jl b/src/ot_reg.jl deleted file mode 100644 index 846d0a0a..00000000 --- a/src/ot_reg.jl +++ /dev/null @@ -1,132 +0,0 @@ -""" - quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) - -Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`, -using the semismooth Newton algorithm [Lorenz 2016]. - -This implementation makes use of IterativeSolvers.jl and SparseArrays.jl. - -Parameters:\n -θ: starting Armijo parameter.\n -tol: tolerance of marginal error.\n -maxiter: maximum interation number.\n -κ: control parameter of Armijo.\n -δ: small constant for the numerical stability of conjugate gradient iterative solver.\n - -Tips: -If the algorithm does not converge, try some different values of θ. - -Reference: -Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4. -""" -function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5) - if !(sum(mu) ≈ sum(nu)) - throw(ArgumentError("Error: mu and nu must lie in the simplex")) - end - - N = length(mu) - M = length(nu) - - # initialize dual potentials as uniforms - a = ones(M) ./ M - b = ones(N) ./ N - γ = spzeros(M, N) - - da = spzeros(M) - db = spzeros(N) - - converged = false - - function DualObjective(a, b) - A = a .* ones(N)' + ones(M) .* b' - C' - - return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b)) - end - - # computes minimizing directions, update γ - function search_dir!(a, b, da, db) - P = a * ones(N)' .+ ones(M) * b' .- C' - - σ = 1.0 * sparse(P .>= 0) - γ = sparse(max.(P, 0) ./ ϵ) - - G = vcat( - hcat(spdiagm(0 => σ * ones(N)), σ), - hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), - ) - - h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) - - x = cg(G + δ * I, -ϵ * h) - - da = x[1:M] - return db = x[(M + 1):end] - end - - function search_dir(a, b) - P = a * ones(N)' .+ ones(M) * b' .- C' - - σ = 1.0 * sparse(P .>= 0) - γ = sparse(max.(P, 0) ./ ϵ) - - G = vcat( - hcat(spdiagm(0 => σ * ones(N)), σ), - hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))), - ) - - h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu) - - x = cg(G + δ * I, -ϵ * h) - - return x[1:M], x[(M + 1):end] - end - - # computes optimal maginitude in the minimizing directions - function search_t(a, b, da, db, θ) - d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu)) - - ϕ₀ = DualObjective(a, b) - t = 1 - - while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d - t *= κ - - if t < 1e-15 - # @warn "@ i = $i, t = $t , armijo did not converge" - break - end - end - return t - end - - for i in 1:maxiter - - # search_dir!(a, b, da, db) - da, db = search_dir(a, b) - - t = search_t(a, b, da, db, θ) - - a += t * da - b += t * db - - err1 = norm(γ * ones(N) - nu, Inf) - err2 = norm(sparse(γ') * ones(M) - mu, Inf) - - if err1 <= tol && err2 <= tol - converged = true - @warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n" - break - elseif i == maxiter - @warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n" - end - - @debug " t = $t" - @debug "marginal @ i = $i: err1 = $err1, err2 = $err2 " - end - - if !converged - @warn "SemiSmooth Newton algorithm did not converge" - end - - return sparse(γ') -end From 71c3378217862e8f377d8cdfc29ac13b6711eee3 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 12:54:38 -0700 Subject: [PATCH 5/7] added ot_reg_cost --- src/OptimalTransport.jl | 21 ++++++++++++++++++++- test/runtests.jl | 3 +++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 146ddeed..5437daf1 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2 export emd, emd2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 -export ot_reg_plan +export ot_reg_plan, ot_reg_cost const MOI = MathOptInterface @@ -534,6 +534,25 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) end end +""" + ot_reg_cost(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...) + +Compute the optimal transport cost between `mu` and `nu` for optimal transport with a +general choice of regulariser `math Ω(γ)`. + +See also: [`ot_reg_plan`](@ref) + +""" +function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) + γ = if (reg_func == "L2") && (method == "lorenz") + quadreg(mu, nu, C, eps; kwargs...) + else + @warn "Unimplemented" + nothing + end + return dot(γ, C) +end + """ quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) diff --git a/test/runtests.jl b/test/runtests.jl index ac0342ee..c33931f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -174,6 +174,9 @@ end >>>>>>> d6c9ee3 (updated tests and docstrings) # need to use a larger tolerance here because of a quirk with the POT solver @test norm(γ - γ_pot, Inf) < 1e-4 + c = ot_reg_cost(μ, ν, C, eps; reg_func="L2", method="lorenz") + c_pot = dot(γ_pot, C) + @test c ≈ c_pot atol = 1e-4 end end From f345878c9a83ceada9af0b2d8014b161cb22dcd4 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 13:05:59 -0700 Subject: [PATCH 6/7] formatting --- src/OptimalTransport.jl | 3 +-- src/variational.jl | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 src/variational.jl diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 5437daf1..d2797344 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -548,12 +548,11 @@ function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) quadreg(mu, nu, C, eps; kwargs...) else @warn "Unimplemented" - nothing + nothing end return dot(γ, C) end - """ quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) diff --git a/src/variational.jl b/src/variational.jl new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/variational.jl @@ -0,0 +1 @@ + From 5adf6355501d9a5dec3b8452018b4de8d89aab0b Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 14:53:41 -0700 Subject: [PATCH 7/7] remove file --- src/variational.jl | 1 - 1 file changed, 1 deletion(-) delete mode 100644 src/variational.jl diff --git a/src/variational.jl b/src/variational.jl deleted file mode 100644 index 8b137891..00000000 --- a/src/variational.jl +++ /dev/null @@ -1 +0,0 @@ -