@@ -7,6 +7,7 @@ module OptimalTransport
77using Distances
88using LinearAlgebra
99using IterativeSolvers, SparseArrays
10+ using LogExpFunctions: LogExpFunctions
1011using MathOptInterface
1112
1213export sinkhorn, sinkhorn2
@@ -171,19 +172,23 @@ function sinkhorn_gibbs(mu, nu, K; tol=1e-9, check_marginal_step=10, maxiter=100
171172end
172173
173174"""
174- sinkhorn(mu, nu , C, eps ; tol=1e-9, check_marginal_step=10, maxiter=1000 )
175+ sinkhorn(μ, ν , C, ε ; tol=1e-9, check_marginal_step=10, maxiter=1_000 )
175176
176- Compute entropically regularised transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic
177- regularization parameter `eps`.
178-
179- Return optimal transport coupling `γ` of the same dimensions as `C` which solves
177+ Compute the optimal transport plan for the entropic regularization optimal transport problem
178+ with source and target marginals `μ` and `ν`, cost matrix `C` of size
179+ `(length(μ), length(ν))`, and entropic regularization parameter `ε`.
180180
181+ The optimal transport plan `γ` is of the same size as `C` and solves
181182```math
182- \\ inf_{\\ gamma \\ in \\ Pi(\\ mu, \\ nu)} \\ langle \\ gamma, C \\ rangle - \\ epsilon H(\\ gamma)
183+ \\ inf_{\\ gamma \\ in \\ Pi(\\ mu, \\ nu)} \\ langle \\ gamma, C \\ rangle
184+ + \\ varepsilon \\ Omega(\\ gamma),
183185```
186+ where ``\\ Omega(\\ gamma) = \\ sum_{i,j} \\ gamma_{i,j} \\ log \\ gamma_{i,j}`` is the entropic
187+ regularization term.
184188
185- where ``H`` is the entropic regulariser, ``H(\\ gamma) = -\\ sum_{i, j} \\ gamma_{ij} \\ log(\\ gamma_{ij})``.
186-
189+ Every `check_marginal_step` steps a convergence check of the error of the marginal
190+ `μ` with absolute tolerance `tol` is performed. After `maxiter` iterations, the
191+ computation is stopped.
187192"""
188193function sinkhorn (mu, nu, C, eps; kwargs... )
189194 # compute Gibbs kernel
@@ -196,24 +201,22 @@ function sinkhorn(mu, nu, C, eps; kwargs...)
196201end
197202
198203"""
199- sinkhorn2(mu, nu, C, eps; plan=nothing, kwargs...)
200-
201- Compute entropically regularised transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic
202- regularization parameter `eps`.
203-
204- Return optimal value of
205-
206- ```math
207- \\ inf_{\\ gamma \\ in \\ Pi(\\ mu, \\ nu)} \\ langle \\ gamma, C \\ rangle - \\ epsilon H(\\ gamma)
208- ```
204+ sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
209205
210- where ``H`` is the entropic regulariser, ``H(\\ gamma) = -\\ sum_{i, j} \\ gamma_{ij} \\ log(\\ gamma_{ij})``.
206+ Solve the entropic regularization optimal transport problem with source and target
207+ marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic
208+ regularization parameter `ε`, and return the optimal cost.
211209
212210A pre-computed optimal transport `plan` may be provided.
213211
212+ !!! note
213+ As the `sinkhorn2` function in the Python Optimal Transport package, this function
214+ returns the optimal transport cost without the regularization term. The cost
215+ with the regularization term can be computed by setting `regularization=true`.
216+
214217See also: [`sinkhorn`](@ref)
215218"""
216- function sinkhorn2 (μ, ν, C, ε; plan= nothing , kwargs... )
219+ function sinkhorn2 (μ, ν, C, ε; regularization = false , plan= nothing , kwargs... )
217220 γ = if plan === nothing
218221 sinkhorn (μ, ν, C, ε; kwargs... )
219222 else
@@ -225,7 +228,14 @@ function sinkhorn2(μ, ν, C, ε; plan=nothing, kwargs...)
225228 )
226229 plan
227230 end
228- return dot (γ, C)
231+
232+ cost = if regularization
233+ dot (γ, C) + ε * sum (LogExpFunctions. xlogx, γ)
234+ else
235+ dot (γ, C)
236+ end
237+
238+ return cost
229239end
230240
231241"""
0 commit comments