diff --git a/docs/src/index.md b/docs/src/index.md index fd2c4a51..6ad38504 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -23,7 +23,7 @@ sinkhorn_unbalanced sinkhorn_unbalanced2 ``` -## Quadratically regularised optimal transport +## Optimal transport with general regularisation ```@docs -quadreg +ot_reg_plan ``` diff --git a/examples/basic/script.jl b/examples/basic/script.jl index d484f739..625c461b 100644 --- a/examples/basic/script.jl +++ b/examples/basic/script.jl @@ -94,7 +94,7 @@ sinkhorn2(μ, ν, C, ε) # resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as # a sparse matrix. -quadreg(μ, ν, C, ε; maxiter=500); +ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500); # ## Stabilized Sinkhorn algorithm # @@ -190,7 +190,7 @@ heatmap( # Notice how the "edges" of the transport plan are sharper if we use quadratic regularisation # instead of entropic regularisation: -γquad = Matrix(quadreg(μ, ν, C, 5; maxiter=500)) +γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500)) heatmap( μsupport, νsupport, diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1414a92d..d2797344 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2 export emd, emd2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 -export quadreg +export ot_reg_plan, ot_reg_cost const MOI = MathOptInterface @@ -506,6 +506,53 @@ function sinkhorn_barycenter( return u_all[1, :] .* (K_all[1] * v_all[1, :]) end +""" + ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...) + +Compute the optimal transport plan between `mu` and `nu` for optimal transport with a +general choice of regulariser `math Ω(γ)`. Solves for `gamma` that minimises + +```math +\\inf_{γ ∈ Π(μ, ν)} \\langle γ, C \\rangle + ε Ω(γ) +``` + +Supported choices of `math Ω` are: +- L2: ``Ω(γ) = \\frac{1}{2} \\| γ \\|_2^2``, `reg_func = "L2"` + +Supported solution methods are: +- L2: `method = "lorenz"` for the semi-smooth Newton method of Lorenz et al. + +References + +Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. Applied Mathematics & Optimization, pp.1-31. +""" +function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) + if (reg_func == "L2") && (method == "lorenz") + return quadreg(mu, nu, C, eps; kwargs...) + else + @warn "Unimplemented" + end +end + +""" + ot_reg_cost(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...) + +Compute the optimal transport cost between `mu` and `nu` for optimal transport with a +general choice of regulariser `math Ω(γ)`. + +See also: [`ot_reg_plan`](@ref) + +""" +function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) + γ = if (reg_func == "L2") && (method == "lorenz") + quadreg(mu, nu, C, eps; kwargs...) + else + @warn "Unimplemented" + nothing + end + return dot(γ, C) +end + """ quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5) diff --git a/test/runtests.jl b/test/runtests.jl index c62b72f7..c33931f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -165,10 +165,18 @@ end # compute optimal transport map (Julia implementation + POT) eps = 0.25 - γ = quadreg(μ, ν, C, eps) +<<<<<<< HEAD + γ = ot_reg_plan(μ, ν, C, eps) γ_pot = POT.Smooth.smooth_ot_dual(μ, ν, C, eps; stopThr=1e-9) +======= + γ = ot_reg_plan(μ, ν, C, eps; reg_func="L2", method="lorenz") + γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps; max_iter=5000)) +>>>>>>> d6c9ee3 (updated tests and docstrings) # need to use a larger tolerance here because of a quirk with the POT solver @test norm(γ - γ_pot, Inf) < 1e-4 + c = ot_reg_cost(μ, ν, C, eps; reg_func="L2", method="lorenz") + c_pot = dot(γ_pot, C) + @test c ≈ c_pot atol = 1e-4 end end