Commit 2eee9f4

Reorganize sinkhorn_unbalanced, improve convergence checks, and fix GPU issues (#80)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent dfcc088 commit 2eee9f4

3 files changed, +286 -68 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "OptimalTransport"
 uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
 authors = ["zsteve <[email protected]>"]
-version = "0.3.3"
+version = "0.3.4"
 
 [deps]
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/OptimalTransport.jl

Lines changed: 191 additions & 61 deletions
@@ -304,97 +304,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
 end
 
 """
-    sinkhorn_unbalanced(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing)
+    sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...)
 
-Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
-using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)`
-for the marginals `mu`, `nu` respectively.
+Compute the optimal transport plan for the unbalanced entropically regularized optimal
+transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
+`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation
+terms `λ1` and `λ2`.
 
-For full generality, the user can specify the soft marginal constraints ``(F_1(\\cdot | \\mu), F_2(\\cdot | \\nu))`` to the problem
+The optimal transport plan `γ` is of the same size as `C` and solves
+```math
+\\inf_{\\gamma} \\langle \\gamma, C \\rangle
++ \\varepsilon \\Omega(\\gamma)
++ \\lambda_1 \\operatorname{KL}(\\gamma 1 | \\mu)
++ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1 | \\nu),
+```
+where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
+regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence.
+
+The keyword arguments supported here are the same as those in the `sinkhorn_unbalanced`
+for unbalanced optimal transport problems with general soft marginal constraints.
+"""
+function sinkhorn_unbalanced(
+    μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs...
+)
+    if proxdiv_F1 !== nothing && proxdiv_F2 !== nothing
+        Base.depwarn(
+            "keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated",
+            :sinkhorn_unbalanced,
+        )
 
+        # have to wrap the "proxdiv" functions since the signature changed
+        # ε was fixed in the function, so we ignore it
+        proxdiv_F1_wrapper(s, p, _) = copyto!(s, proxdiv_F1(s, p))
+        proxdiv_F2_wrapper(s, p, _) = copyto!(s, proxdiv_F2(s, p))
+
+        return sinkhorn_unbalanced(
+            μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs...
+        )
+    end
+
+    # define "proxdiv" functions for the unbalanced OT problem
+    proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ)))
+    proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1)
+    proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2)
+
+    return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...)
+end
+
+"""
+    sinkhorn_unbalanced(
+        μ, ν, C, proxdivF1!, proxdivF2!, ε;
+        atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000,
+    )
+
+Compute the optimal transport plan for the unbalanced entropically regularized optimal
+transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
+`(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal
+constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`.
+
+The optimal transport plan `γ` is of the same size as `C` and solves
+```math
+\\inf_{\\gamma} \\langle \\gamma, C \\rangle
++ \\varepsilon \\Omega(\\gamma)
++ F_1(\\gamma 1, \\mu)
++ F_2(\\gamma^{\\mathsf{T}} 1, \\nu),
+```
+where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
+regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)`` are soft marginal
+constraints for the source and target marginals.
+
+The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv"
+functions of ``F_1(\\cdot, p)`` and ``F_2(\\cdot, p)`` at ``s`` for the entropic
+regularization parameter ``\\varepsilon``. They have to be mutating and overwrite the first
+argument `s` with the result of their computations.
+
+Mathematically, the "proxdiv" functions are defined as
+```math
+\\operatorname{proxdiv}_{F_i}(s, p, \\varepsilon)
+= \\operatorname{prox}^{\\operatorname{KL}}_{F_i(\\cdot, p)/\\varepsilon}(s) \\oslash s
+```
+where ``\\oslash`` denotes element-wise division and
+``\\operatorname{prox}_{F_i(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}`` is the proximal
+operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler
+(``\\operatorname{KL}``) divergence. It is defined as
+```math
+\\operatorname{prox}_{F}^{\\operatorname{KL}}(x)
+= \\operatorname{argmin}_{y} F(y) + \\operatorname{KL}(y|x)
+```
+and can be computed in closed-form for specific choices of ``F``. For instance, if
+``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then
 ```math
-\\min_\\gamma \\epsilon \\mathrm{KL}(\\gamma | \\exp(-C/\\epsilon)) + F_1(\\gamma_1 | \\mu) + F_2(\\gamma_2 | \\nu)
+\\operatorname{prox}_{F(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}(x)
+= x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}},
 ```
+where all operators are acting pointwise.[^CPSV18]
+
+Every `check_convergence` steps it is assessed if the algorithm is converged by checking if
+the iterates of the scaling factor in the current and previous iteration satisfy
+`isapprox(vcat(a, b), vcat(aprev, bprev); atol=atol, rtol=rtol)` where `a` and `b` are the
+current iterates and `aprev` and `bprev` the previous ones. The default `rtol` depends on
+the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the computation is stopped.
 
-via `math\\mathrm{proxdiv}_{F_1}(s, p)` and `math\\mathrm{proxdiv}_{F_2}(s, p)` (see Chizat et al., 2016 for details on this). If specified, the algorithm will use the user-specified F1, F2 rather than the default (a KL-divergence).
+[^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609.
+
+See also: [`sinkhorn_unbalanced2`](@ref)
 """
 function sinkhorn_unbalanced(
-    mu,
-    nu,
+    μ,
+    ν,
     C,
-    lambda1,
-    lambda2,
-    eps;
-    tol=1e-9,
-    max_iter=1000,
-    verbose=false,
-    proxdiv_F1=nothing,
-    proxdiv_F2=nothing,
+    proxdivF1!,
+    proxdivF2!,
+    ε;
+    tol=nothing,
+    atol=tol,
+    rtol=nothing,
+    max_iter=nothing,
+    maxiter=max_iter,
+    check_convergence::Int=10,
 )
-    function proxdiv_KL(s, eps, lambda, p)
-        return @. (s^(eps / (eps + lambda)) * p^(lambda / (eps + lambda))) / s
+    # deprecations
+    if tol !== nothing
+        Base.depwarn(
+            "keyword argument `tol` is deprecated, please use `atol` and `rtol`",
+            :sinkhorn_unbalanced,
+        )
     end
+    if max_iter !== nothing
+        Base.depwarn(
+            "keyword argument `max_iter` is deprecated, please use `maxiter`",
+            :sinkhorn_unbalanced,
+        )
+    end
+
+    # compute Gibbs kernel
+    K = @. exp(-C / ε)
 
-    a = ones(size(mu, 1))
-    b = ones(size(nu, 1))
-    a_old = a
-    b_old = b
-    tmp_a = zeros(size(nu, 1))
-    tmp_b = zeros(size(mu, 1))
+    # set default values of squared tolerances
+    T = float(Base.promote_eltype(μ, ν, K))
+    sqatol = atol === nothing ? 0 : atol^2
+    sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? zero(T) : eps(T)) : rtol^2
 
-    K = @. exp(-C / eps)
+    # initialize iterates
+    a = similar(μ, T)
+    sum!(a, K)
+    proxdivF1!(a, μ, ε)
+    b = similar(ν, T)
+    mul!(b, K', a)
+    proxdivF2!(b, ν, ε)
 
-    iter = 1
+    # caches for convergence checks
+    a_old = similar(a)
+    b_old = similar(b)
 
-    while true
-        a_old = a
-        b_old = b
-        tmp_b = K * b
-        if proxdiv_F1 == nothing
-            a = proxdiv_KL(tmp_b, eps, lambda1, mu)
-        else
-            a = proxdiv_F1(tmp_b, mu)
-        end
-        tmp_a = K' * a
-        if proxdiv_F2 == nothing
-            b = proxdiv_KL(tmp_a, eps, lambda2, nu)
-        else
-            b = proxdiv_F2(tmp_a, nu)
+    isconverged = false
+    _maxiter = maxiter === nothing ? 1_000 : maxiter
+    for iter in 1:_maxiter
+        # update cache if necessary
+        ischeck = iter % check_convergence == 0
+        if ischeck
+            copyto!(a_old, a)
+            copyto!(b_old, b)
         end
-        iter += 1
-        if iter % 10 == 0
-            err_a =
-                maximum(abs.(a - a_old)) / max(maximum(abs.(a)), maximum(abs.(a_old)), 1)
-            err_b =
-                maximum(abs.(b - b_old)) / max(maximum(abs.(b)), maximum(abs.(b_old)), 1)
-            if verbose
-                println("Iteration $iter, err = ", 0.5 * (err_a + err_b))
-            end
-            if (0.5 * (err_a + err_b) < tol) || iter > max_iter
+
+        # compute next iterates
+        mul!(a, K, b)
+        proxdivF1!(a, μ, ε)
+        mul!(b, K', a)
+        proxdivF2!(b, ν, ε)
+
+        # check convergence of the scaling factors
+        if ischeck
+            # compute norm of current and previous scaling factors and their difference
+            sqnorm_a_b = sum(abs2, a) + sum(abs2, b)
+            sqnorm_a_b_old = sum(abs2, a_old) + sum(abs2, b_old)
+            a_old .-= a
+            b_old .-= b
+            sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old)
+            @debug "Sinkhorn algorithm (" *
+                   string(iter) *
+                   "/" *
+                   string(_maxiter) *
+                   ": squared Euclidean distance of iterates = " *
+                   string(sqeuclidean_a_b)
+
+            # check convergence of `a`
+            if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old))
+                @debug "Sinkhorn algorithm ($iter/$_maxiter): converged"
+                isconverged = true
                 break
             end
         end
     end
-    if iter > max_iter && verbose
-        println("Warning: exited before convergence")
+
+    if !isconverged
+        @warn "Sinkhorn algorithm ($_maxiter/$_maxiter): not converged"
     end
-    return Diagonal(a) * K * Diagonal(b)
+
+    return K .* a .* b'
 end
 
 """
-    sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...)
+    sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...)
+    sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...)
 
-Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
-using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)`
-for the marginals mu, nu respectively.
+Compute the optimal transport plan for the unbalanced entropically regularized optimal
+transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
+`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation
+terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and
+`proxdivF2!`.
 
-A pre-computed optimal transport `plan` may be provided.
+A pre-computed optimal transport `plan` may be provided. The other keyword arguments
+supported here are the same as those in the [`sinkhorn_unbalanced`](@ref) for unbalanced
+optimal transport problems with general soft marginal constraints.
 
 See also: [`sinkhorn_unbalanced`](@ref)
 """
-function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...)
+function sinkhorn_unbalanced2(
+    μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs...
+)
     γ = if plan === nothing
-        sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; kwargs...)
+        sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...)
     else
         # check dimensions
         size(C) == (length(μ), length(ν)) ||
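
The scaling iteration introduced in this diff can be tried in isolation. The following is a minimal sketch under stated assumptions, not the package implementation: `proxdiv_kl!` and `unbalanced_sketch` are hypothetical names, the loop runs a fixed number of iterations instead of using the `atol`/`rtol` convergence check, and only the standard library `LinearAlgebra` is used. It relies on the identity behind the KL "proxdiv": dividing `s^(ε/(ε+λ)) * p^(λ/(ε+λ))` by `s` gives `(p/s)^(λ/(ε+λ))`.

```julia
using LinearAlgebra

# KL "proxdiv" in closed form: overwrite `s` with (p ./ s) .^ (λ / (ε + λ)).
# Hypothetical standalone helper mirroring the local `proxdivF!` in the diff.
proxdiv_kl!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ)))

# Hypothetical helper: runs a fixed number of scaling iterations
# (no convergence check, unlike the function in the diff).
function unbalanced_sketch(μ, ν, C, λ1, λ2, ε; iters=500)
    K = exp.(-C ./ ε)                  # Gibbs kernel
    a = similar(μ, Float64)            # overwritten by `mul!` below
    b = ones(Float64, length(ν))       # start from b = 1, so K * b == row sums of K
    for _ in 1:iters
        mul!(a, K, b)                  # a <- K * b
        proxdiv_kl!(a, μ, ε, λ1)       # a <- proxdiv of the source marginal term
        mul!(b, K', a)                 # b <- K' * a
        proxdiv_kl!(b, ν, ε, λ2)       # b <- proxdiv of the target marginal term
    end
    return K .* a .* b'                # plan γ = Diagonal(a) * K * Diagonal(b)
end

μ = [0.5, 0.5]
ν = [0.25, 0.75]
C = [0.0 1.0; 1.0 0.0]
γ = unbalanced_sketch(μ, ν, C, 1.0, 1.0, 0.1)
```

Because the marginal constraints are soft, `sum(γ; dims=2)` and `sum(γ; dims=1)` are only expected to be close to `μ` and `ν`, not equal to them; larger `λ1` and `λ2` pull them closer.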
