Skip to content

Commit 6acc912

Browse files
Switch to plan (#61)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 365e4b4 commit 6acc912

File tree

3 files changed

+68
-35
lines changed

3 files changed

+68
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/OptimalTransport.jl

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828
"""
2929
emd(μ, ν, C, optimizer)
3030
31-
Compute the optimal transport map `γ` for the Monge-Kantorovich problem with source
31+
Compute the optimal transport plan `γ` for the Monge-Kantorovich problem with source
3232
histogram `μ`, target histogram `ν`, and cost matrix `C` of size `(length(μ), length(ν))`
3333
which solves
3434
```math
@@ -84,15 +84,15 @@ function emd(μ, ν, C, model::MOI.ModelLike)
8484
# compute optimal solution
8585
MOI.optimize!(model)
8686
status = MOI.get(model, MOI.TerminationStatus())
87-
status === MOI.OPTIMAL || error("failed to compute optimal transport map: ", status)
87+
status === MOI.OPTIMAL || error("failed to compute optimal transport plan: ", status)
8888
p = MOI.get(model, MOI.VariablePrimal(), x)
8989
γ = reshape(p, nμ, nν)
9090

9191
return γ
9292
end
9393

9494
"""
95-
emd2(μ, ν, C, optimizer; map=nothing)
95+
emd2(μ, ν, C, optimizer; plan=nothing)
9696
9797
Compute the optimal transport cost (a scalar) for the Monge-Kantorovich problem with source
9898
histogram `μ`, target histogram `ν`, and cost matrix `C` of size `(length(μ), length(ν))`
@@ -105,20 +105,25 @@ The corresponding linear programming problem is solved with the user-provided `o
105105
Possible choices are `Tulip.Optimizer()` and `Clp.Optimizer()` in the `Tulip` and `Clp`
106106
packages, respectively.
107107
108-
A pre-computed optimal transport `map` may be provided.
108+
A pre-computed optimal transport `plan` may be provided.
109109
"""
110-
function emd2(μ, ν, C, optimizer; map=nothing)
111-
γ = if map === nothing
112-
# compute optimal transport map
110+
function emd2(μ, ν, C, optimizer; map=nothing, plan=map)
111+
# check deprecation
112+
if map !== nothing
113+
Base.depwarn("the keyword argument `map` is deprecated, please use `plan`", :emd2)
114+
end
115+
116+
γ = if plan === nothing
117+
# compute optimal transport plan
113118
emd(μ, ν, C, optimizer)
114119
else
115120
# check dimensions
116121
size(C) == (length(μ), length(ν)) ||
117122
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
118-
size(map) == size(C) || error(
119-
"optimal transport map `map` and cost matrix `C` must be of the same size",
123+
size(plan) == size(C) || error(
124+
"optimal transport plan `plan` and cost matrix `C` must be of the same size",
120125
)
121-
map
126+
plan
122127
end
123128
return dot(γ, C)
124129
end
@@ -130,7 +135,7 @@ Compute dual potentials `u` and `v` for histograms `mu` and `nu` and Gibbs kerne
130135
the Sinkhorn algorithm (Peyre et al., 2019)
131136
132137
The Gibbs kernel `K` is given by `K = exp.(- C / eps)` where `C` is the cost matrix and
133-
`eps` the entropic regularization parameter. The optimal transport map for histograms `u`
138+
`eps` the entropic regularization parameter. The optimal transport plan for histograms `u`
134139
and `v` and cost matrix `C` with regularization parameter `eps` can be computed as
135140
`Diagonal(u) * K * Diagonal(v)`.
136141
"""
@@ -181,7 +186,7 @@ end
181186
"""
182187
sinkhorn(mu, nu, C, eps; tol=1e-9, check_marginal_step=10, maxiter=1000)
183188
184-
Compute entropically regularised transport map of histograms `mu` and `nu` with cost matrix `C` and entropic
189+
Compute entropically regularised transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic
185190
regularization parameter `eps`.
186191
187192
Return optimal transport coupling `γ` of the same dimensions as `C` which solves
@@ -204,7 +209,7 @@ function sinkhorn(mu, nu, C, eps; kwargs...)
204209
end
205210

206211
"""
207-
sinkhorn2(mu, nu, C, eps; tol=1e-9, check_marginal_step=10, maxiter=1000)
212+
sinkhorn2(mu, nu, C, eps; plan=nothing, kwargs...)
208213
209214
Compute entropically regularised transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic
210215
regularization parameter `eps`.
@@ -217,27 +222,36 @@ Return optimal value of
217222
218223
where ``H`` is the entropic regulariser, ``H(\\gamma) = -\\sum_{i, j} \\gamma_{ij} \\log(\\gamma_{ij})``.
219224
220-
A pre-computed optimal transport `map` may be provided.
225+
A pre-computed optimal transport `plan` may be provided.
226+
227+
See also: [`sinkhorn`](@ref)
221228
"""
222-
function sinkhorn2(μ, ν, C, ε; map=nothing, kwargs...)
223-
γ = if map === nothing
229+
function sinkhorn2(μ, ν, C, ε; map=nothing, plan=map, kwargs...)
230+
# check deprecation
231+
if map !== nothing
232+
Base.depwarn(
233+
"the keyword argument `map` is deprecated, please use `plan`", :sinkhorn2
234+
)
235+
end
236+
237+
γ = if plan === nothing
224238
sinkhorn(μ, ν, C, ε; kwargs...)
225239
else
226240
# check dimensions
227241
size(C) == (length(μ), length(ν)) ||
228242
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
229-
size(map) == size(C) || error(
230-
"optimal transport map `map` and cost matrix `C` must be of the same size",
243+
size(plan) == size(C) || error(
244+
"optimal transport plan `plan` and cost matrix `C` must be of the same size",
231245
)
232-
map
246+
plan
233247
end
234248
return dot(γ, C)
235249
end
236250

237251
"""
238252
sinkhorn_unbalanced(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing)
239253
240-
Computes the optimal transport map of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
254+
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
241255
using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)`
242256
for the marginals `mu`, `nu` respectively.
243257
@@ -313,35 +327,43 @@ function sinkhorn_unbalanced(
313327
end
314328

315329
"""
316-
sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing)
330+
sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...)
317331
318332
Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`,
319333
using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)`
320334
for the marginals mu, nu respectively.
321335
322-
See documentation for `sinkhorn_unbalanced` for additional details.
336+
A pre-computed optimal transport `plan` may be provided.
323337
324-
A pre-computed optimal transport `map` may be provided.
338+
See also: [`sinkhorn_unbalanced`](@ref)
325339
"""
326-
function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; map=nothing, kwargs...)
327-
γ = if map === nothing
340+
function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; map=nothing, plan=map, kwargs...)
341+
# check deprecation
342+
if map !== nothing
343+
Base.depwarn(
344+
"the keyword argument `map` is deprecated, please use `plan`",
345+
:sinkhorn_unbalanced2,
346+
)
347+
end
348+
349+
γ = if plan === nothing
328350
sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; kwargs...)
329351
else
330352
# check dimensions
331353
size(C) == (length(μ), length(ν)) ||
332354
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
333-
size(map) == size(C) || error(
334-
"optimal transport map `map` and cost matrix `C` must be of the same size",
355+
size(plan) == size(C) || error(
356+
"optimal transport plan `plan` and cost matrix `C` must be of the same size",
335357
)
336-
map
358+
plan
337359
end
338360
return dot(γ, C)
339361
end
340362

341363
"""
342364
sinkhorn_stabilized_epsscaling(mu, nu, C, eps; absorb_tol = 1e3, max_iter = 1000, tol = 1e-9, lambda = 0.5, k = 5, verbose = false)
343365
344-
Compute optimal transport map of histograms `mu` and `nu` with cost matrix `C` and entropic regularisation parameter `eps`.
366+
Compute optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularisation parameter `eps`.
345367
Uses stabilized Sinkhorn algorithm with epsilon-scaling (Schmitzer et al., 2019).
346368
347369
`k` epsilon-scaling steps are used with scaling factor `lambda`, i.e. sequentially solve Sinkhorn with regularisation parameters
@@ -382,7 +404,7 @@ end
382404
"""
383405
sinkhorn_stabilized(mu, nu, C, eps; absorb_tol = 1e3, max_iter = 1000, tol = 1e-9, alpha = nothing, beta = nothing, return_duals = false, verbose = false)
384406
385-
Compute optimal transport map of histograms `mu` and `nu` with cost matrix `C` and entropic regularisation parameter `eps`.
407+
Compute optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularisation parameter `eps`.
386408
Uses stabilized Sinkhorn algorithm (Schmitzer et al., 2019).
387409
"""
388410
function sinkhorn_stabilized(
@@ -520,7 +542,7 @@ end
520542
"""
521543
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
522544
523-
Computes the optimal transport map of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
545+
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
524546
using the semismooth Newton algorithm [Lorenz 2016].
525547
526548
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.

test/runtests.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ Random.seed!(100)
4444
@test cost pot_cost atol = 1e-5
4545

4646
# ensure that provided map is used
47-
cost2 = emd2(similar(μ), similar(ν), C, lp; map=P)
47+
cost2 = @test_deprecated(emd2(similar(μ), similar(ν), C, lp; map=P))
48+
@test cost2 cost
49+
50+
cost2 = emd2(similar(μ), similar(ν), C, lp; plan=P)
4851
@test cost2 cost
4952
end
5053

@@ -72,7 +75,10 @@ end
7275
@test c c_pot atol = 1e-9
7376

7477
# ensure that provided map is used
75-
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(); map=γ)
78+
c2 = @test_deprecated(sinkhorn2(similar(μ), similar(ν), C, rand(); map=γ))
79+
@test c2 c
80+
81+
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ)
7682
@test c2 c
7783
end
7884

@@ -148,7 +154,12 @@ end
148154
@test c c_pot atol = 1e-9
149155

150156
# ensure that provided map is used
151-
c2 = sinkhorn_unbalanced2(similar(μ), similar(ν), C, rand(), rand(), rand(); map=γ)
157+
c2 = @test_deprecated(
158+
sinkhorn_unbalanced2(similar(μ), similar(ν), C, rand(), rand(), rand(); map=γ)
159+
)
160+
@test c2 c
161+
162+
c2 = sinkhorn_unbalanced2(similar(μ), similar(ν), C, rand(), rand(), rand(); plan=γ)
152163
@test c2 c
153164
end
154165
end

0 commit comments

Comments
 (0)