@@ -304,97 +304,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
304304end
305305
306306"""
307- sinkhorn_unbalanced(mu, nu , C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing )
307+ sinkhorn_unbalanced(μ, ν , C, λ1::Real, λ2::Real, ε; kwargs... )
308308
309- Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
310- using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)`
311- for the marginals `mu`, `nu` respectively.
309+ Compute the optimal transport plan for the unbalanced entropically regularized optimal
310+ transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
311+ `(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation
312+ terms `λ1` and `λ2`.
312313
313- For full generality, the user can specify the soft marginal constraints ``(F_1(\\ cdot | \\ mu), F_2(\\ cdot | \\ nu))`` to the problem
314+ The optimal transport plan `γ` is of the same size as `C` and solves
315+ ```math
316+ \\ inf_{\\ gamma} \\ langle \\ gamma, C \\ rangle
317+ + \\ varepsilon \\ Omega(\\ gamma)
318+ + \\ lambda_1 \\ operatorname{KL}(\\ gamma 1 | \\ mu)
319+ + \\ lambda_2 \\ operatorname{KL}(\\ gamma^{\\ mathsf{T}} 1 | \\ nu),
320+ ```
321+ where ``\\ Omega(\\ gamma) = \\ sum_{i,j} \\ gamma_{i,j} \\ log \\ gamma_{i,j}`` is the entropic
322+ regularization term and ``\\ operatorname{KL}`` is the Kullback-Leibler divergence.
323+
324+ The keyword arguments supported here are the same as those in the `sinkhorn_unbalanced`
325+ for unbalanced optimal transport problems with general soft marginal constraints.
326+ """
327+ function sinkhorn_unbalanced (
328+ μ, ν, C, λ1:: Real , λ2:: Real , ε; proxdiv_F1= nothing , proxdiv_F2= nothing , kwargs...
329+ )
330+ if proxdiv_F1 != = nothing && proxdiv_F2 != = nothing
331+ Base. depwarn (
332+ " keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated" ,
333+ :sinkhorn_unbalanced ,
334+ )
314335
336+ # have to wrap the "proxdiv" functions since the signature changed
337+ # ε was fixed in the function, so we ignore it
338+ proxdiv_F1_wrapper (s, p, _) = copyto! (s, proxdiv_F1 (s, p))
339+ proxdiv_F2_wrapper (s, p, _) = copyto! (s, proxdiv_F2 (s, p))
340+
341+ return sinkhorn_unbalanced (
342+ μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs...
343+ )
344+ end
345+
346+ # define "proxdiv" functions for the unbalanced OT problem
347+ proxdivF! (s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ)))
348+ proxdivF1! (s, p, ε) = proxdivF! (s, p, ε, λ1)
349+ proxdivF2! (s, p, ε) = proxdivF! (s, p, ε, λ2)
350+
351+ return sinkhorn_unbalanced (μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs... )
352+ end
353+
354+ """
355+ sinkhorn_unbalanced(
356+ μ, ν, C, proxdivF1!, proxdivF2!, ε;
357+ atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000,
358+ )
359+
360+ Compute the optimal transport plan for the unbalanced entropically regularized optimal
361+ transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
362+ `(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal
363+ constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`.
364+
365+ The optimal transport plan `γ` is of the same size as `C` and solves
366+ ```math
367+ \\ inf_{\\ gamma} \\ langle \\ gamma, C \\ rangle
368+ + \\ varepsilon \\ Omega(\\ gamma)
369+ + F_1(\\ gamma 1, \\ mu)
370+ + F_2(\\ gamma^{\\ mathsf{T}} 1, \\ nu),
371+ ```
372+ where ``\\ Omega(\\ gamma) = \\ sum_{i,j} \\ gamma_{i,j} \\ log \\ gamma_{i,j}`` is the entropic
373+ regularization term and ``F_1(\\ cdot, \\ mu)`` and ``F_2(\\ cdot, \\ nu)`` are soft marginal
374+ constraints for the source and target marginals.
375+
376+ The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv"
377+ functions of ``F_1(\\ cdot, p)`` and ``F_2(\\ cdot, p)`` at ``s`` for the entropic
378+ regularization parameter ``\\ varepsilon``. They have to be mutating and overwrite the first
379+ argument `s` with the result of their computations.
380+
381+ Mathematically, the "proxdiv" functions are defined as
382+ ```math
383+ \\ operatorname{proxdiv}_{F_i}(s, p, \\ varepsilon)
384+ = \\ operatorname{prox}^{\\ operatorname{KL}}_{F_i(\\ cdot, p)/\\ varepsilon}(s) \\ oslash s
385+ ```
386+ where ``\\ oslash`` denotes element-wise division and
387+ ``\\ operatorname{prox}_{F_i(\\ cdot, p)/\\ varepsilon}^{\\ operatorname{KL}}`` is the proximal
388+ operator of ``F_i(\\ cdot, p)/\\ varepsilon`` for the Kullback-Leibler
389+ (``\\ operatorname{KL}``) divergence. It is defined as
390+ ```math
391+ \\ operatorname{prox}_{F}^{\\ operatorname{KL}}(x)
392+ = \\ operatorname{argmin}_{y} F(y) + \\ operatorname{KL}(y|x)
393+ ```
394+ and can be computed in closed-form for specific choices of ``F``. For instance, if
395+ ``F(\\ cdot, p) = \\ lambda \\ operatorname{KL}(\\ cdot | p)`` (``\\ lambda > 0``), then
315396```math
316- \\ min_\\ gamma \\ epsilon \\ mathrm{KL}(\\ gamma | \\ exp(-C/\\ epsilon)) + F_1(\\ gamma_1 | \\ mu) + F_2(\\ gamma_2 | \\ nu)
397+ \\ operatorname{prox}_{F(\\ cdot, p)/\\ varepsilon}^{\\ operatorname{KL}}(x)
398+ = x^{\\ frac{\\ varepsilon}{\\ varepsilon + \\ lambda}} p^{\\ frac{\\ lambda}{\\ varepsilon + \\ lambda}},
317399```
400+ where all operators are acting pointwise.[^CPSV18]
401+
402+ Every `check_convergence` steps it is assessed if the algorithm is converged by checking if
403+ the iterates of the scaling factor in the current and previous iteration satisfy
404+ `isapprox(vcat(a, b), vcat(aprev, bprev); atol=atol, rtol=rtol)` where `a` and `b` are the
405+ current iterates and `aprev` and `bprev` the previous ones. The default `rtol` depends on
406+ the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the computation is stopped.
318407
319- via `math\\ mathrm{proxdiv}_{F_1}(s, p)` and `math\\ mathrm{proxdiv}_{F_2}(s, p)` (see Chizat et al., 2016 for details on this). If specified, the algorithm will use the user-specified F1, F2 rather than the default (a KL-divergence).
408+ [^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609.
409+
410+ See also: [`sinkhorn_unbalanced2`](@ref)
320411"""
321412function sinkhorn_unbalanced (
322- mu ,
323- nu ,
413+ μ ,
414+ ν ,
324415 C,
325- lambda1,
326- lambda2,
327- eps;
328- tol= 1e-9 ,
329- max_iter= 1000 ,
330- verbose= false ,
331- proxdiv_F1= nothing ,
332- proxdiv_F2= nothing ,
416+ proxdivF1!,
417+ proxdivF2!,
418+ ε;
419+ tol= nothing ,
420+ atol= tol,
421+ rtol= nothing ,
422+ max_iter= nothing ,
423+ maxiter= max_iter,
424+ check_convergence:: Int = 10 ,
333425)
334- function proxdiv_KL (s, eps, lambda, p)
335- return @. (s^ (eps / (eps + lambda)) * p^ (lambda / (eps + lambda))) / s
426+ # deprecations
427+ if tol != = nothing
428+ Base. depwarn (
429+ " keyword argument `tol` is deprecated, please use `atol` and `rtol`" ,
430+ :sinkhorn_unbalanced ,
431+ )
336432 end
433+ if max_iter != = nothing
434+ Base. depwarn (
435+ " keyword argument `max_iter` is deprecated, please use `maxiter`" ,
436+ :sinkhorn_unbalanced ,
437+ )
438+ end
439+
440+ # compute Gibbs kernel
441+ K = @. exp (- C / ε)
337442
338- a = ones (size (mu, 1 ))
339- b = ones (size (nu, 1 ))
340- a_old = a
341- b_old = b
342- tmp_a = zeros (size (nu, 1 ))
343- tmp_b = zeros (size (mu, 1 ))
443+ # set default values of squared tolerances
444+ T = float (Base. promote_eltype (μ, ν, K))
445+ sqatol = atol === nothing ? 0 : atol^ 2
446+ sqrtol = rtol === nothing ? (sqatol > zero (sqatol) ? zero (T) : eps (T)) : rtol^ 2
344447
345- K = @. exp (- C / eps)
448+ # initialize iterates
449+ a = similar (μ, T)
450+ sum! (a, K)
451+ proxdivF1! (a, μ, ε)
452+ b = similar (ν, T)
453+ mul! (b, K' , a)
454+ proxdivF2! (b, ν, ε)
346455
347- iter = 1
456+ # caches for convergence checks
457+ a_old = similar (a)
458+ b_old = similar (b)
348459
349- while true
350- a_old = a
351- b_old = b
352- tmp_b = K * b
353- if proxdiv_F1 == nothing
354- a = proxdiv_KL (tmp_b, eps, lambda1, mu)
355- else
356- a = proxdiv_F1 (tmp_b, mu)
357- end
358- tmp_a = K' * a
359- if proxdiv_F2 == nothing
360- b = proxdiv_KL (tmp_a, eps, lambda2, nu)
361- else
362- b = proxdiv_F2 (tmp_a, nu)
460+ isconverged = false
461+ _maxiter = maxiter === nothing ? 1_000 : maxiter
462+ for iter in 1 : _maxiter
463+ # update cache if necessary
464+ ischeck = iter % check_convergence == 0
465+ if ischeck
466+ copyto! (a_old, a)
467+ copyto! (b_old, b)
363468 end
364- iter += 1
365- if iter % 10 == 0
366- err_a =
367- maximum (abs .(a - a_old)) / max (maximum (abs .(a)), maximum (abs .(a_old)), 1 )
368- err_b =
369- maximum (abs .(b - b_old)) / max (maximum (abs .(b)), maximum (abs .(b_old)), 1 )
370- if verbose
371- println (" Iteration $iter , err = " , 0.5 * (err_a + err_b))
372- end
373- if (0.5 * (err_a + err_b) < tol) || iter > max_iter
469+
470+ # compute next iterates
471+ mul! (a, K, b)
472+ proxdivF1! (a, μ, ε)
473+ mul! (b, K' , a)
474+ proxdivF2! (b, ν, ε)
475+
476+ # check convergence of the scaling factors
477+ if ischeck
478+ # compute norm of current and previous scaling factors and their difference
479+ sqnorm_a_b = sum (abs2, a) + sum (abs2, b)
480+ sqnorm_a_b_old = sum (abs2, a_old) + sum (abs2, b_old)
481+ a_old .- = a
482+ b_old .- = b
483+ sqeuclidean_a_b = sum (abs2, a_old) + sum (abs2, b_old)
484+ @debug " Sinkhorn algorithm (" *
485+ string (iter) *
486+ " /" *
487+ string (_maxiter) *
488+ " : squared Euclidean distance of iterates = " *
489+ string (sqeuclidean_a_b)
490+
491+ # check convergence of `a`
492+ if sqeuclidean_a_b < max (sqatol, sqrtol * max (sqnorm_a_b, sqnorm_a_b_old))
493+ @debug " Sinkhorn algorithm ($iter /$_maxiter ): converged"
494+ isconverged = true
374495 break
375496 end
376497 end
377498 end
378- if iter > max_iter && verbose
379- println (" Warning: exited before convergence" )
499+
500+ if ! isconverged
501+ @warn " Sinkhorn algorithm ($_maxiter /$_maxiter ): not converged"
380502 end
381- return Diagonal (a) * K * Diagonal (b)
503+
504+ return K .* a .* b'
382505end
383506
384507"""
385- sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...)
508+ sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...)
509+ sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...)
386510
387- Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
388- using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)`
389- for the marginals mu, nu respectively.
511+ Compute the optimal transport plan for the unbalanced entropically regularized optimal
512+ transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
513+ `(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation
514+ terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and
515+ `proxdivF2!`.
390516
391- A pre-computed optimal transport `plan` may be provided.
517+ A pre-computed optimal transport `plan` may be provided. The other keyword arguments
518+ supported here are the same as those in the [`sinkhorn_unbalanced`](@ref) for unbalanced
519+ optimal transport problems with general soft marginal constraints.
392520
393521See also: [`sinkhorn_unbalanced`](@ref)
394522"""
395- function sinkhorn_unbalanced2 (μ, ν, C, λ1, λ2, ε; plan= nothing , kwargs... )
523+ function sinkhorn_unbalanced2 (
524+ μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan= nothing , kwargs...
525+ )
396526 γ = if plan === nothing
397- sinkhorn_unbalanced (μ, ν, C, λ1, λ2 , ε; kwargs... )
527+ sinkhorn_unbalanced (μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2 , ε; kwargs... )
398528 else
399529 # check dimensions
400530 size (C) == (length (μ), length (ν)) ||
0 commit comments