Skip to content

Commit 15e296a

Browse files
authored
Add option to return cost with regularization term (#75)
1 parent 0907838 commit 15e296a

File tree

3 files changed

+54
-27
lines changed

3 files changed

+54
-27
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1011
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1112
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1213

1314
[compat]
1415
Distances = "0.9.0, 0.10"
1516
IterativeSolvers = "0.8.4, 0.9"
17+
LogExpFunctions = "0.2"
1618
MathOptInterface = "0.9"
1719
julia = "1"
1820

src/OptimalTransport.jl

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module OptimalTransport
77
using Distances
88
using LinearAlgebra
99
using IterativeSolvers, SparseArrays
10+
using LogExpFunctions: LogExpFunctions
1011
using MathOptInterface
1112

1213
export sinkhorn, sinkhorn2
@@ -171,19 +172,23 @@ function sinkhorn_gibbs(mu, nu, K; tol=1e-9, check_marginal_step=10, maxiter=100
171172
end
172173

173174
"""
174-
sinkhorn(mu, nu, C, eps; tol=1e-9, check_marginal_step=10, maxiter=1000)
175+
sinkhorn(μ, ν, C, ε; tol=1e-9, check_marginal_step=10, maxiter=1_000)
175176
176-
Compute entropically regularised transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic
177-
regularization parameter `eps`.
178-
179-
Return optimal transport coupling `γ` of the same dimensions as `C` which solves
177+
Compute the optimal transport plan for the entropic regularization optimal transport problem
178+
with source and target marginals `μ` and `ν`, cost matrix `C` of size
179+
`(length(μ), length(ν))`, and entropic regularization parameter `ε`.
180180
181+
The optimal transport plan `γ` is of the same size as `C` and solves
181182
```math
182-
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle - \\epsilon H(\\gamma)
183+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
184+
+ \\varepsilon \\Omega(\\gamma),
183185
```
186+
where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
187+
regularization term.
184188
185-
where ``H`` is the entropic regulariser, ``H(\\gamma) = -\\sum_{i, j} \\gamma_{ij} \\log(\\gamma_{ij})``.
186-
189+
Every `check_marginal_step` steps a convergence check of the error of the marginal
190+
`μ` with absolute tolerance `tol` is performed. After `maxiter` iterations, the
191+
computation is stopped.
187192
"""
188193
function sinkhorn(mu, nu, C, eps; kwargs...)
189194
# compute Gibbs kernel
@@ -196,24 +201,22 @@ function sinkhorn(mu, nu, C, eps; kwargs...)
196201
end
197202

198203
"""
199-
sinkhorn2(mu, nu, C, eps; plan=nothing, kwargs...)
200-
201-
Compute entropically regularised transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic
202-
regularization parameter `eps`.
203-
204-
Return optimal value of
205-
206-
```math
207-
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle - \\epsilon H(\\gamma)
208-
```
204+
sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
209205
210-
where ``H`` is the entropic regulariser, ``H(\\gamma) = -\\sum_{i, j} \\gamma_{ij} \\log(\\gamma_{ij})``.
206+
Solve the entropic regularization optimal transport problem with source and target
207+
marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic
208+
regularization parameter `ε`, and return the optimal cost.
211209
212210
A pre-computed optimal transport `plan` may be provided.
213211
212+
!!! note
213+
As the `sinkhorn2` function in the Python Optimal Transport package, this function
214+
returns the optimal transport cost without the regularization term. The cost
215+
with the regularization term can be computed by setting `regularization=true`.
216+
214217
See also: [`sinkhorn`](@ref)
215218
"""
216-
function sinkhorn2(μ, ν, C, ε; plan=nothing, kwargs...)
219+
function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
217220
γ = if plan === nothing
218221
sinkhorn(μ, ν, C, ε; kwargs...)
219222
else
@@ -225,7 +228,14 @@ function sinkhorn2(μ, ν, C, ε; plan=nothing, kwargs...)
225228
)
226229
plan
227230
end
228-
return dot(γ, C)
231+
232+
cost = if regularization
233+
dot(γ, C) + ε * sum(LogExpFunctions.xlogx, γ)
234+
else
235+
dot(γ, C)
236+
end
237+
238+
return cost
229239
end
230240

231241
"""

test/runtests.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,24 @@ end
8282
γ_pot = POT.sinkhorn(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)
8383
@test norm- γ_pot, Inf) < 1e-9
8484

85-
# compute optimal transport cost (Julia implementation + POT)
85+
# compute optimal transport cost
8686
c = sinkhorn2(μ, ν, C, eps; maxiter=5_000)
87+
88+
# with regularization term
89+
c_w_regularization = sinkhorn2(μ, ν, C, eps; maxiter=5_000, regularization=true)
90+
@test c_w_regularization c + eps * sum(x -> iszero(x) ? x : x * log(x), γ)
91+
92+
# compare with POT
8793
c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)[1]
88-
@test c c_pot atol = 1e-9
94+
@test c_pot c atol = 1e-9
8995

90-
# ensure that provided map is used
96+
# ensure that provided map is used and correct
9197
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ)
9298
@test c2 c
99+
c2_w_regularization = sinkhorn2(
100+
similar(μ), similar(ν), C, eps; plan=γ, regularization=true
101+
)
102+
@test c2_w_regularization c_w_regularization
93103
end
94104

95105
# different element type
@@ -109,12 +119,17 @@ end
109119
γ_pot = POT.sinkhorn(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)
110120
@test norm- γ_pot, Inf) < Base.eps(Float32)
111121

112-
# compute optimal transport cost (Julia implementation + POT)
122+
# compute optimal transport cost
113123
c = sinkhorn2(μ, ν, C, eps; maxiter=5_000)
114124
@test c isa Float32
115125

126+
# with regularization term
127+
c_w_regularization = sinkhorn2(μ, ν, C, eps; maxiter=5_000, regularization=true)
128+
@test c_w_regularization c + eps * sum(x -> iszero(x) ? x : x * log(x), γ)
129+
130+
# compare with POT
116131
c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)[1]
117-
@test c c_pot atol = Base.eps(Float32)
132+
@test c_pot c atol = Base.eps(Float32)
118133
end
119134
end
120135

0 commit comments

Comments
 (0)