Skip to content

Commit dfcc088

Browse files
More sophisticated convergence checks in sinkhorn (#79)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 63d0eef commit dfcc088

File tree

4 files changed

+140
-52
lines changed

4 files changed

+140
-52
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ jobs:
2020
os:
2121
- ubuntu-latest
2222
- windows-latest
23-
- macOS-latest
2423
arch:
2524
- x64
2625
include:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/OptimalTransport.jl

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -120,65 +120,120 @@ function emd2(μ, ν, C, optimizer; plan=nothing)
120120
end
121121

122122
"""
123-
sinkhorn_gibbs(mu, nu, K; tol=1e-9, check_marginal_step=10, maxiter=1000)
123+
sinkhorn_gibbs(
124+
μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
125+
)
126+
127+
Compute the dual potentials for the entropically regularized optimal transport problem
128+
with source and target marginals `μ` and `ν` and Gibbs kernel `K` using the Sinkhorn
129+
algorithm.
124130
125-
Compute dual potentials `u` and `v` for histograms `mu` and `nu` and Gibbs kernel `K` using
126-
the Sinkhorn algorithm (Peyre et al., 2019)
131+
The Gibbs kernel `K` is defined as
132+
```math
133+
K = \\exp(-C / \\varepsilon),
134+
```
135+
where ``C`` is the cost matrix and ``\\varepsilon`` the entropic regularization parameter.
136+
The corresponding optimal transport plan can be computed from the dual potentials ``u``
137+
and ``v`` as
138+
```math
139+
\\gamma = \\operatorname{diag}(u) K \\operatorname{diag}(v).
140+
```
127141
128-
The Gibbs kernel `K` is given by `K = exp.(- C / eps)` where `C` is the cost matrix and
129-
`eps` the entropic regularization parameter. The optimal transport plan for histograms `u`
130-
and `v` and cost matrix `C` with regularization parameter `eps` can be computed as
131-
`Diagonal(u) * K * Diagonal(v)`.
142+
Every `check_convergence` steps it is assessed if the algorithm is converged by checking if
143+
the iterate of the transport plan `G` satisfies
144+
```julia
145+
isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
146+
```
147+
The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
148+
the computation is stopped.
132149
"""
133-
function sinkhorn_gibbs(mu, nu, K; tol=1e-9, check_marginal_step=10, maxiter=1000)
134-
if !(sum(mu) sum(nu))
135-
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
150+
function sinkhorn_gibbs(
151+
μ,
152+
ν,
153+
K;
154+
tol=nothing,
155+
atol=tol,
156+
rtol=nothing,
157+
check_marginal_step=nothing,
158+
check_convergence=check_marginal_step,
159+
maxiter::Int=1_000,
160+
)
161+
if tol !== nothing
162+
Base.depwarn(
163+
"keyword argument `tol` is deprecated, please use `atol` and `rtol`",
164+
:sinkhorn_gibbs,
165+
)
136166
end
167+
if check_marginal_step !== nothing
168+
Base.depwarn(
169+
"keyword argument `check_marginal_step` is deprecated, please use `check_convergence`",
170+
:sinkhorn_gibbs,
171+
)
172+
end
173+
sum(μ) sum(ν) ||
174+
throw(ArgumentError("source and target marginals must have the same mass"))
175+
176+
# set default values of tolerances
177+
T = float(Base.promote_eltype(μ, ν, K))
178+
_atol = atol === nothing ? 0 : atol
179+
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol
137180

138181
# initial iteration
139-
temp_v = vec(sum(K; dims=2))
140-
u = mu ./ temp_v
141-
temp_u = K' * u
142-
v = nu ./ temp_u
182+
u = μ ./ sum(K; dims=2)
183+
v = ν ./ (K' * u)
184+
tmp1 = K * v
185+
tmp2 = similar(u)
143186

187+
norm_μ = sum(abs, μ) # for convergence check
144188
isconverged = false
189+
check_step = check_convergence === nothing ? 10 : check_convergence
145190
for iter in 0:maxiter
146-
# check mu marginal
147-
if iter % check_marginal_step == 0
148-
mul!(temp_v, K, v)
149-
@. temp_v = abs(mu - u * temp_v)
150-
151-
err = maximum(temp_v)
152-
@debug "Sinkhorn algorithm: iteration $iter" err
191+
if iter % check_step == 0
192+
# check source marginal
193+
# do not overwrite `tmp1` but reuse it for computing `u` if not converged
194+
@. tmp2 = u * tmp1
195+
norm_uKv = sum(abs, tmp2)
196+
@. tmp2 = μ - tmp2
197+
norm_diff = sum(abs, tmp2)
198+
199+
@debug "Sinkhorn algorithm (" *
200+
string(iter) *
201+
"/" *
202+
string(maxiter) *
203+
": absolute error of source marginal = " *
204+
string(norm_diff)
153205

154206
# check stopping criterion
155-
if err < tol
207+
if norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
208+
@debug "Sinkhorn algorithm ($iter/$maxiter): converged"
156209
isconverged = true
157210
break
158211
end
159212
end
160213

161214
# perform next iteration
162215
if iter < maxiter
163-
mul!(temp_v, K, v)
164-
@. u = mu / temp_v
165-
mul!(temp_u, K', u)
166-
@. v = nu / temp_u
216+
@. u = μ / tmp1
217+
mul!(v, K', u)
218+
@. v = ν / v
219+
mul!(tmp1, K, v)
167220
end
168221
end
169222

170223
if !isconverged
171-
@warn "Sinkhorn algorithm did not converge"
224+
@warn "Sinkhorn algorithm ($maxiter/$maxiter): not converged"
172225
end
173226

174227
return u, v
175228
end
176229

177230
"""
178-
sinkhorn(μ, ν, C, ε; tol=1e-9, check_marginal_step=10, maxiter=1_000)
231+
sinkhorn(
232+
μ, ν, C, ε; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
233+
)
179234
180-
Compute the optimal transport plan for the entropic regularization optimal transport problem
181-
with source and target marginals `μ` and `ν`, cost matrix `C` of size
235+
Compute the optimal transport plan for the entropically regularized optimal transport
236+
problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
182237
`(length(μ), length(ν))`, and entropic regularization parameter `ε`.
183238
184239
The optimal transport plan `γ` is of the same size as `C` and solves
@@ -189,28 +244,35 @@ The optimal transport plan `γ` is of the same size as `C` and solves
189244
where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
190245
regularization term.
191246
192-
Every `check_marginal_step` steps a convergence check of the error of the marginal
193-
`μ` with absolute tolerance `tol` is performed. After `maxiter` iterations, the
194-
computation is stopped.
247+
Every `check_convergence` steps it is assessed if the algorithm is converged by checking if
248+
the iterate of the transport plan `G` satisfies
249+
```julia
250+
isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
251+
```
252+
The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
253+
the computation is stopped.
254+
255+
See also: [`sinkhorn2`](@ref)
195256
"""
196-
function sinkhorn(mu, nu, C, eps; kwargs...)
257+
function sinkhorn(μ, ν, C, ε; kwargs...)
197258
# compute Gibbs kernel
198-
K = @. exp(-C / eps)
259+
K = @. exp(-C / ε)
199260

200261
# compute dual potentials
201-
u, v = sinkhorn_gibbs(mu, nu, K; kwargs...)
262+
u, v = sinkhorn_gibbs(μ, ν, K; kwargs...)
202263

203-
return Diagonal(u) * K * Diagonal(v)
264+
return K .* u .* v'
204265
end
205266

206267
"""
207268
sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
208269
209-
Solve the entropic regularization optimal transport problem with source and target
270+
Solve the entropically regularized optimal transport problem with source and target
210271
marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic
211272
regularization parameter `ε`, and return the optimal cost.
212273
213-
A pre-computed optimal transport `plan` may be provided.
274+
A pre-computed optimal transport `plan` may be provided. The other keyword arguments
275+
supported here are the same as those in the [`sinkhorn`](@ref) function.
214276
215277
!!! note
216278
As the `sinkhorn2` function in the Python Optimal Transport package, this function

test/runtests.jl

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,20 @@ end
124124

125125
# compute optimal transport map (Julia implementation + POT)
126126
eps = 0.01
127-
γ = sinkhorn(μ, ν, C, eps; maxiter=5_000)
127+
γ = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-9)
128128
γ_pot = POT.sinkhorn(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)
129-
@test norm- γ_pot, Inf) < 1e-9
129+
@test γ_pot γ rtol = 1e-6
130130

131131
# compute optimal transport cost
132-
c = sinkhorn2(μ, ν, C, eps; maxiter=5_000)
132+
c = sinkhorn2(μ, ν, C, eps; maxiter=5_000, rtol=1e-9)
133133

134134
# with regularization term
135135
c_w_regularization = sinkhorn2(μ, ν, C, eps; maxiter=5_000, regularization=true)
136136
@test c_w_regularization c + eps * sum(x -> iszero(x) ? x : x * log(x), γ)
137137

138138
# compare with POT
139139
c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)[1]
140-
@test c_pot c atol = 1e-9
140+
@test c_pot c
141141

142142
# ensure that provided map is used and correct
143143
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ)
@@ -159,23 +159,50 @@ end
159159

160160
# compute optimal transport map (Julia implementation + POT)
161161
eps = 0.01f0
162-
γ = sinkhorn(μ, ν, C, eps; maxiter=5_000)
162+
γ = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-6)
163163
@test eltype(γ) === Float32
164164

165-
γ_pot = POT.sinkhorn(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)
166-
@test norm- γ_pot, Inf) < Base.eps(Float32)
165+
γ_pot = POT.sinkhorn(μ, ν, C, eps; numItermax=5_000, stopThr=1e-6)
166+
@test Float32.(γ_pot) γ rtol = 1e-3
167167

168168
# compute optimal transport cost
169-
c = sinkhorn2(μ, ν, C, eps; maxiter=5_000)
169+
c = sinkhorn2(μ, ν, C, eps; maxiter=5_000, rtol=1e-6)
170170
@test c isa Float32
171171

172172
# with regularization term
173-
c_w_regularization = sinkhorn2(μ, ν, C, eps; maxiter=5_000, regularization=true)
173+
c_w_regularization = sinkhorn2(
174+
μ, ν, C, eps; maxiter=5_000, rtol=1e-6, regularization=true
175+
)
174176
@test c_w_regularization c + eps * sum(x -> iszero(x) ? x : x * log(x), γ)
175177

176178
# compare with POT
177-
c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)[1]
178-
@test c_pot c atol = Base.eps(Float32)
179+
c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-6)[1]
180+
@test Float32(c_pot) c rtol = 1e-3
181+
end
182+
183+
@testset "deprecations" begin
184+
# create two uniform histograms
185+
μ = fill(1 / M, M)
186+
ν = fill(1 / N, N)
187+
188+
# create random cost matrix
189+
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
190+
191+
# check `sinkhorn2`
192+
eps = 0.01
193+
c = sinkhorn2(μ, ν, C, eps; atol=1e-6)
194+
@test (@test_deprecated sinkhorn2(μ, ν, C, eps; tol=1e-6)) == c
195+
c = sinkhorn2(μ, ν, C, eps; check_convergence=5)
196+
@test (@test_deprecated sinkhorn2(μ, ν, C, eps; check_marginal_step=5)) == c
197+
198+
# check `sinkhorn_gibbs
199+
K = @. exp(-C / eps)
200+
γ = OptimalTransport.sinkhorn_gibbs(μ, ν, K; atol=1e-6)
201+
@test (@test_deprecated OptimalTransport.sinkhorn_gibbs(μ, ν, K; tol=1e-6)) == γ
202+
γ = OptimalTransport.sinkhorn_gibbs(μ, ν, K; check_convergence=5)
203+
@test (@test_deprecated OptimalTransport.sinkhorn_gibbs(
204+
μ, ν, K; check_marginal_step=5
205+
)) == γ
179206
end
180207
end
181208

0 commit comments

Comments
 (0)