-
Notifications
You must be signed in to change notification settings - Fork 12
Reorganize sinkhorn_unbalanced, improve convergence checks, and fix GPU issues
#80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
4aa431c
55a6edd
7b5e2cf
b50289c
cb14d9e
ebee41d
6acb85e
eabbaad
85f1c30
c587490
980e9ab
9585ccd
31dc3f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| name = "OptimalTransport" | ||
| uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" | ||
| authors = ["zsteve <[email protected]>"] | ||
| version = "0.3.2" | ||
| version = "0.3.3" | ||
|
|
||
| [deps] | ||
| Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -239,97 +239,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...) | |
| end | ||
|
|
||
| """ | ||
| sinkhorn_unbalanced(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing) | ||
| sinkhorn_unbalanced( | ||
| μ, ν, C, λ1::Real, λ2::Real, ε; | ||
| atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, | ||
| ) | ||
|
|
||
| Compute the optimal transport plan for the unbalanced entropic regularization optimal | ||
| transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size | ||
| `(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation | ||
| terms `λ1` and `λ2`. | ||
|
|
||
| Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, | ||
| using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` | ||
| for the marginals `mu`, `nu` respectively. | ||
| The optimal transport plan `γ` is of the same size as `C` and solves | ||
| ```math | ||
| \\inf_{\\gamma} \\langle \\gamma, C \\rangle | ||
| + \\varepsilon \\Omega(\\gamma) | ||
| + \\lambda_1 \\operatorname{KL}(\\gamma 1, \\mu) | ||
| + \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1, \\nu), | ||
| ``` | ||
| where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic | ||
| regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence. | ||
|
|
||
| For full generality, the user can specify the soft marginal constraints ``(F_1(\\cdot | \\mu), F_2(\\cdot | \\nu))`` to the problem | ||
| Every `check_convergence` steps a convergence check of the error of the scaling factors | ||
| with absolute tolerance `atol` and relative tolerance `rtol` is performed. The default | ||
| `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the | ||
| computation is stopped. | ||
| """ | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| function sinkhorn_unbalanced( | ||
| μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs... | ||
| ) | ||
| if proxdiv_F1 !== nothing && proxdiv_F2 !== nothing | ||
| Base.depwarn( | ||
| "keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated", | ||
| :sinkhorn_unbalanced, | ||
| ) | ||
|
|
||
| # have to wrap the "proxdiv" functions since the signature changed | ||
| # ε was fixed in the function, so we ignore it | ||
| proxdiv_F1_wrapper(s, p, _) = copyto!(s, proxdiv_F1(s, p)) | ||
| proxdiv_F2_wrapper(s, p, _) = copyto!(s, proxdiv_F2(s, p)) | ||
|
|
||
| return sinkhorn_unbalanced( | ||
| μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs... | ||
| ) | ||
| end | ||
|
|
||
| # define "proxdiv" functions for the unbalanced OT problem | ||
| proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ))) | ||
| proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1) | ||
| proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2) | ||
|
|
||
| return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...) | ||
| end | ||
|
|
||
| """ | ||
| sinkhorn_unbalanced( | ||
| μ, ν, C, proxdivF1!, proxdivF2!, ε; | ||
| atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, | ||
| ) | ||
|
|
||
| Compute the optimal transport plan for the unbalanced entropic regularization optimal | ||
| transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size | ||
| `(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal | ||
| constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`. | ||
|
|
||
| The optimal transport plan `γ` is of the same size as `C` and solves | ||
| ```math | ||
| \\min_\\gamma \\epsilon \\mathrm{KL}(\\gamma | \\exp(-C/\\epsilon)) + F_1(\\gamma_1 | \\mu) + F_2(\\gamma_2 | \\nu) | ||
| \\inf_{\\gamma} \\langle \\gamma, C \\rangle | ||
| + \\varepsilon \\Omega(\\gamma) | ||
| + F_1(\\gamma 1, \\mu) | ||
| + F_2(\\gamma^{\\mathsf{T}} 1, \\nu), | ||
| ``` | ||
| where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic | ||
| regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)` are soft marginal | ||
| constraints for the source and target marginals. | ||
|
|
||
| via `math\\mathrm{proxdiv}_{F_1}(s, p)` and `math\\mathrm{proxdiv}_{F_2}(s, p)` (see Chizat et al., 2016 for details on this). If specified, the algorithm will use the user-specified F1, F2 rather than the default (a KL-divergence). | ||
| The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv" | ||
| functions of ``F_1(\\cdot, p)`` and ``F_2(\\cdot, p)`` at ``s`` for the entropic | ||
| regularization parameter ``\\varepsilon``. They have to be mutating and overwrite the first | ||
| argument `s` with the result of their computations. | ||
|
|
||
| Mathematically, the "proxdiv" functions are defined as | ||
| ```math | ||
| \\operatorname{proxdiv}_{F_i}(s, p, \\varepsilon) | ||
| = \\operatorname{prox}^{\\operatorname{KL}}_{F_i(\\cdot, p)/\\varepsilon}(s) \\oslash s | ||
| ``` | ||
| where ``\\oslash`` denotes element-wise division and | ||
| ``\\operatorname{prox}_{F_i(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}`` is the proximal | ||
| operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler | ||
| (``\\operatorname{KL}``) divergence. It is defined as | ||
| ```math | ||
| \\operatorname{prox}_{F}^{\\operatorname{KL}}(x) | ||
| = \\operatorname{argmin}_{y} F(y) + \\operatorname{KL}(y|x) | ||
| ``` | ||
| and can be computed in closed-form for specific choices of ``F``. For instance, if | ||
| ``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then | ||
| ```math | ||
| \\operatorname{prox}_{F(\\cdot | p)/\\varepsilon}^{\\operatorname{KL}}(x) | ||
| = x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}}, | ||
| ``` | ||
| where all operators are acting pointwise.[^CPSV18] | ||
|
|
||
| Every `check_convergence` steps a convergence check of the error of the scaling factors | ||
| with absolute tolerance `atol` and relative tolerance `rtol` is performed. The default | ||
| `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the | ||
| computation is stopped. | ||
|
|
||
| [^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609. | ||
| """ | ||
| function sinkhorn_unbalanced( | ||
| mu, | ||
| nu, | ||
| μ, | ||
| ν, | ||
| C, | ||
| lambda1, | ||
| lambda2, | ||
| eps; | ||
| tol=1e-9, | ||
| max_iter=1000, | ||
| verbose=false, | ||
| proxdiv_F1=nothing, | ||
| proxdiv_F2=nothing, | ||
| proxdivF1!, | ||
| proxdivF2!, | ||
| ε; | ||
| tol=nothing, | ||
| atol=tol, | ||
| rtol=nothing, | ||
| max_iter=nothing, | ||
| maxiter=max_iter, | ||
| check_convergence::Int=10, | ||
| ) | ||
| function proxdiv_KL(s, eps, lambda, p) | ||
| return @. (s^(eps / (eps + lambda)) * p^(lambda / (eps + lambda))) / s | ||
| # deprecations | ||
| if tol !== nothing | ||
| Base.depwarn( | ||
| "keyword argument `tol` is deprecated, please use `atol` and `rtol`", | ||
| :sinkhorn_unbalanced, | ||
| ) | ||
| end | ||
| if max_iter !== nothing | ||
| Base.depwarn( | ||
| "keyword argument `max_iter` is deprecated, please use `maxiter`", | ||
| :sinkhorn_unbalanced, | ||
| ) | ||
| end | ||
|
|
||
| a = ones(size(mu, 1)) | ||
| b = ones(size(nu, 1)) | ||
| a_old = a | ||
| b_old = b | ||
| tmp_a = zeros(size(nu, 1)) | ||
| tmp_b = zeros(size(mu, 1)) | ||
| # compute Gibbs kernel | ||
| K = @. exp(-C / ε) | ||
|
|
||
| K = @. exp(-C / eps) | ||
| # set default values of squared tolerances | ||
| T = float(Base.promote_eltype(μ, ν, K)) | ||
| sqatol = atol === nothing ? 0 : atol^2 | ||
| sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? zero(T) : eps(T)) : rtol^2 | ||
|
|
||
| iter = 1 | ||
| # initialize iterates | ||
| a = similar(μ, T) | ||
| sum!(a, K) | ||
| proxdivF1!(a, μ, ε) | ||
| b = similar(ν, T) | ||
| mul!(b, K', a) | ||
| proxdivF2!(b, ν, ε) | ||
|
|
||
| while true | ||
| a_old = a | ||
| b_old = b | ||
| tmp_b = K * b | ||
| if proxdiv_F1 == nothing | ||
| a = proxdiv_KL(tmp_b, eps, lambda1, mu) | ||
| else | ||
| a = proxdiv_F1(tmp_b, mu) | ||
| end | ||
| tmp_a = K' * a | ||
| if proxdiv_F2 == nothing | ||
| b = proxdiv_KL(tmp_a, eps, lambda2, nu) | ||
| else | ||
| b = proxdiv_F2(tmp_a, nu) | ||
| # caches for convergence checks | ||
| a_old = similar(a) | ||
| b_old = similar(b) | ||
|
|
||
| isconverged = false | ||
| _maxiter = maxiter === nothing ? 1_000 : maxiter | ||
| for iter in 1:_maxiter | ||
| # update cache if necessary | ||
| ischeck = iter % check_convergence == 0 | ||
| if ischeck | ||
| copyto!(a_old, a) | ||
| copyto!(b_old, b) | ||
| end | ||
| iter += 1 | ||
| if iter % 10 == 0 | ||
| err_a = | ||
| maximum(abs.(a - a_old)) / max(maximum(abs.(a)), maximum(abs.(a_old)), 1) | ||
| err_b = | ||
| maximum(abs.(b - b_old)) / max(maximum(abs.(b)), maximum(abs.(b_old)), 1) | ||
| if verbose | ||
| println("Iteration $iter, err = ", 0.5 * (err_a + err_b)) | ||
| end | ||
| if (0.5 * (err_a + err_b) < tol) || iter > max_iter | ||
|
|
||
| # compute next iterates | ||
| mul!(a, K, b) | ||
| proxdivF1!(a, μ, ε) | ||
| mul!(b, K', a) | ||
| proxdivF2!(b, ν, ε) | ||
|
|
||
| # check convergence of the scaling factors | ||
| if ischeck | ||
| # compute norm of current and previous scaling factors and their difference | ||
| sqnorm_a_b = sum(abs2, a) + sum(abs2, b) | ||
| sqnorm_a_b_old = sum(abs2, a_old) + sum(abs2, b_old) | ||
| a_old .-= a | ||
| b_old .-= b | ||
| sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old) | ||
| @debug "Sinkhorn algorithm (" * | ||
| string(iter) * | ||
| "/" * | ||
| string(_maxiter) * | ||
| ": squared Euclidean distance of iterates = " * | ||
| string(sqeuclidean_a_b) | ||
|
|
||
| # check convergence of `a` | ||
| if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old)) | ||
| @debug "Sinkhorn algorithm ($iter/$_maxiter): converged" | ||
| isconverged = true | ||
| break | ||
| end | ||
| end | ||
| end | ||
| if iter > max_iter && verbose | ||
| println("Warning: exited before convergence") | ||
|
|
||
| if !isconverged | ||
| @warn "Sinkhorn algorithm ($_maxiter/$_maxiter): not converged" | ||
| end | ||
| return Diagonal(a) * K * Diagonal(b) | ||
|
|
||
| return K .* a .* b' | ||
| end | ||
|
|
||
| """ | ||
| sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...) | ||
| sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) | ||
| sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...) | ||
|
|
||
| Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, | ||
| using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` | ||
| for the marginals mu, nu respectively. | ||
| Compute the optimal transport plan for the unbalanced entropic regularization optimal | ||
| transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size | ||
| `(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation | ||
| terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and | ||
| `proxdivF2!`. | ||
|
|
||
| A pre-computed optimal transport `plan` may be provided. | ||
|
|
||
| See also: [`sinkhorn_unbalanced`](@ref) | ||
| """ | ||
| function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) | ||
| function sinkhorn_unbalanced2( | ||
| μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs... | ||
| ) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we rename to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe. I still think the main problem is not the name of the functions but the amount of functions - they are all doing the same thing but for slightly different problems (many even for the same) and different algorithms. So the natural approach would be to be able to dispatch both on the problem and the algorithm, which would also solve the problem that #66 tries to address but doesn't fix in a general and extendable way.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In any case, I would suggest that both renaming and reorganization of functions should be done in a separate PR.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, you are right. It's better that we decide on the naming, and then submit in another PR.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's a problem since if we have different functions for every combination of problem and algorithm (such as e.g.
|
||
| γ = if plan === nothing | ||
| sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; kwargs...) | ||
| sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...) | ||
| else | ||
| # check dimensions | ||
| size(C) == (length(μ), length(ν)) || | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.