@@ -120,65 +120,120 @@ function emd2(μ, ν, C, optimizer; plan=nothing)
120120end
121121
122122"""
123- sinkhorn_gibbs(mu, nu, K; tol=1e-9, check_marginal_step=10, maxiter=1000)
123+ sinkhorn_gibbs(
124+ μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
125+ )
126+
127+ Compute the dual potentials for the entropically regularized optimal transport problem
128+ with source and target marginals `μ` and `ν` and Gibbs kernel `K` using the Sinkhorn
129+ algorithm.
124130
125- Compute dual potentials `u` and `v` for histograms `mu` and `nu` and Gibbs kernel `K` using
126- the Sinkhorn algorithm (Peyre et al., 2019)
131+ The Gibbs kernel `K` is defined as
132+ ```math
133+ K = \\exp(-C / \\varepsilon),
134+ ```
135+ where ``C`` is the cost matrix and ``\\varepsilon`` the entropic regularization parameter.
136+ The corresponding optimal transport plan can be computed from the dual potentials ``u``
137+ and ``v`` as
138+ ```math
139+ \\gamma = \\operatorname{diag}(u) K \\operatorname{diag}(v).
140+ ```
127141
128- The Gibbs kernel `K` is given by `K = exp.(- C / eps)` where `C` is the cost matrix and
129- `eps` the entropic regularization parameter. The optimal transport plan for histograms `u`
130- and `v` and cost matrix `C` with regularization parameter `eps` can be computed as
131- `Diagonal(u) * K * Diagonal(v)`.
142+ Every `check_convergence` steps, convergence is assessed by checking whether
143+ the current iterate of the transport plan `G` satisfies
144+ ```julia
145+ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
146+ ```
147+ The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
148+ the computation is stopped.
132149"""
133- function sinkhorn_gibbs (mu, nu, K; tol= 1e-9 , check_marginal_step= 10 , maxiter= 1000 )
134- if ! (sum (mu) ≈ sum (nu))
135- throw (ArgumentError (" Error: mu and nu must lie in the simplex" ))
150+ function sinkhorn_gibbs (
151+ μ,
152+ ν,
153+ K;
154+ tol= nothing ,
155+ atol= tol,
156+ rtol= nothing ,
157+ check_marginal_step= nothing ,
158+ check_convergence= check_marginal_step,
159+ maxiter:: Int = 1_000 ,
160+ )
161+ if tol != = nothing
162+ Base. depwarn (
163+ " keyword argument `tol` is deprecated, please use `atol` and `rtol`" ,
164+ :sinkhorn_gibbs ,
165+ )
136166 end
167+ if check_marginal_step != = nothing
168+ Base. depwarn (
169+ " keyword argument `check_marginal_step` is deprecated, please use `check_convergence`" ,
170+ :sinkhorn_gibbs ,
171+ )
172+ end
173+ sum (μ) ≈ sum (ν) ||
174+ throw (ArgumentError (" source and target marginals must have the same mass" ))
175+
176+ # set default values of tolerances
177+ T = float (Base. promote_eltype (μ, ν, K))
178+ _atol = atol === nothing ? 0 : atol
179+ _rtol = rtol === nothing ? (_atol > zero (_atol) ? zero (T) : sqrt (eps (T))) : rtol
137180
138181 # initial iteration
139- temp_v = vec ( sum (K; dims= 2 ) )
140- u = mu ./ temp_v
141- temp_u = K' * u
142- v = nu ./ temp_u
182+ u = μ ./ sum (K; dims= 2 )
183+ v = ν ./ (K ' * u)
184+ tmp1 = K * v
185+ tmp2 = similar (u)
143186
187+ norm_μ = sum (abs, μ) # for convergence check
144188 isconverged = false
189+ check_step = check_convergence === nothing ? 10 : check_convergence
145190 for iter in 0 : maxiter
146- # check mu marginal
147- if iter % check_marginal_step == 0
148- mul! (temp_v, K, v)
149- @. temp_v = abs (mu - u * temp_v)
150-
151- err = maximum (temp_v)
152- @debug " Sinkhorn algorithm: iteration $iter " err
191+ if iter % check_step == 0
192+ # check source marginal
193+ # do not overwrite `tmp1` but reuse it for computing `u` if not converged
194+ @. tmp2 = u * tmp1
195+ norm_uKv = sum (abs, tmp2)
196+ @. tmp2 = μ - tmp2
197+ norm_diff = sum (abs, tmp2)
198+
199+ @debug " Sinkhorn algorithm (" *
200+ string (iter) *
201+ " /" *
202+ string (maxiter) *
203+ " : absolute error of source marginal = " *
204+ string (norm_diff)
153205
154206 # check stopping criterion
155- if err < tol
207+ if norm_diff < max (_atol, _rtol * max (norm_μ, norm_uKv))
208+ @debug " Sinkhorn algorithm ($iter /$maxiter ): converged"
156209 isconverged = true
157210 break
158211 end
159212 end
160213
161214 # perform next iteration
162215 if iter < maxiter
163- mul! (temp_v, K, v)
164- @. u = mu / temp_v
165- mul! (temp_u, K ' , u)
166- @. v = nu / temp_u
216+ @. u = μ / tmp1
217+ mul! (v, K ' , u)
218+ @. v = ν / v
219+ mul! (tmp1, K, v)
167220 end
168221 end
169222
170223 if ! isconverged
171- @warn " Sinkhorn algorithm did not converge "
224+ @warn " Sinkhorn algorithm ( $maxiter / $maxiter ): not converged "
172225 end
173226
174227 return u, v
175228end
176229
177230"""
178- sinkhorn(μ, ν, C, ε; tol=1e-9, check_marginal_step=10, maxiter=1_000)
231+ sinkhorn(
232+ μ, ν, C, ε; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
233+ )
179234
180- Compute the optimal transport plan for the entropic regularization optimal transport problem
181- with source and target marginals `μ` and `ν`, cost matrix `C` of size
235+ Compute the optimal transport plan for the entropically regularized optimal transport
236+ problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
182237`(length(μ), length(ν))`, and entropic regularization parameter `ε`.
183238
184239The optimal transport plan `γ` is of the same size as `C` and solves
@@ -189,28 +244,35 @@ The optimal transport plan `γ` is of the same size as `C` and solves
189244where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
190245regularization term.
191246
192- Every `check_marginal_step` steps a convergence check of the error of the marginal
193- `μ` with absolute tolerance `tol` is performed. After `maxiter` iterations, the
194- computation is stopped.
247+ Every `check_convergence` steps, convergence is assessed by checking whether
248+ the current iterate of the transport plan `G` satisfies
249+ ```julia
250+ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
251+ ```
252+ The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
253+ the computation is stopped.
254+
255+ See also: [`sinkhorn2`](@ref)
195256"""
196- function sinkhorn (mu, nu , C, eps ; kwargs... )
257+ function sinkhorn (μ, ν , C, ε ; kwargs... )
197258 # compute Gibbs kernel
198- K = @. exp (- C / eps )
259+ K = @. exp (- C / ε )
199260
200261 # compute dual potentials
201- u, v = sinkhorn_gibbs (mu, nu , K; kwargs... )
262+ u, v = sinkhorn_gibbs (μ, ν , K; kwargs... )
202263
203- return Diagonal (u) * K * Diagonal (v)
264+ return K .* u .* v '
204265end
205266
206267"""
207268 sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
208269
209- Solve the entropic regularization optimal transport problem with source and target
270+ Solve the entropically regularized optimal transport problem with source and target
210271marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic
211272regularization parameter `ε`, and return the optimal cost.
212273
213- A pre-computed optimal transport `plan` may be provided.
274+ A pre-computed optimal transport `plan` may be provided. The other keyword arguments
275+ supported here are the same as those in the [`sinkhorn`](@ref) function.
214276
215277!!! note
216278 As the `sinkhorn2` function in the Python Optimal Transport package, this function
0 commit comments