Skip to content

Commit 63d0eef

Browse files
1d Optimal Transport (#45)
Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 8f12f6c commit 63d0eef

File tree

4 files changed

+114
-0
lines changed

4 files changed

+114
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@ version = "0.3.2"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
8+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1112
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
13+
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1214
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1315

1416
[compat]
1517
Distances = "0.9.0, 0.10"
18+
Distributions = "0.25"
1619
IterativeSolvers = "0.8.4, 0.9"
1720
LogExpFunctions = "0.2"
1821
MathOptInterface = "0.9"
22+
QuadGK = "2"
1923
julia = "1"
2024

2125
[extras]

docs/src/index.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33

44
## Exact optimal transport (Kantorovich) problem
5+
56
```@docs
67
emd
78
emd2
9+
ot_plan
10+
ot_cost
811
```
912

1013
## Entropically regularised optimal transport

src/OptimalTransport.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@ using LinearAlgebra
99
using IterativeSolvers, SparseArrays
1010
using LogExpFunctions: LogExpFunctions
1111
using MathOptInterface
12+
using Distributions
13+
using QuadGK
1214

1315
export sinkhorn, sinkhorn2
1416
export emd, emd2
1517
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1618
export sinkhorn_unbalanced, sinkhorn_unbalanced2
1719
export quadreg
20+
export ot_cost, ot_plan
1821

1922
const MOI = MathOptInterface
2023

@@ -649,4 +652,62 @@ function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
649652
return sparse')
650653
end
651654

655+
"""
656+
ot_cost(
657+
c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution; plan=nothing
658+
)
659+
660+
Compute the optimal transport cost for the Monge-Kantorovich problem with univariate
661+
distributions `μ` and `ν` as source and target marginals and cost function `c` of
662+
the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.
663+
664+
In this setting, the optimal transport cost can be computed as
665+
```math
666+
\\int_0^1 c(F_\\mu^{-1}(x), F_\\nu^{-1}(x)) \\mathrm{d}x
667+
```
668+
where ``F_\\mu^{-1}`` and ``F_\\nu^{-1}`` are the quantile functions of `μ` and `ν`,
669+
respectively.
670+
671+
A pre-computed optimal transport `plan` may be provided.
672+
673+
See also: [`ot_plan`](@ref), [`emd2`](@ref)
674+
"""
675+
function ot_cost(
    c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution; plan=nothing
)
    # Select the integrand over the probability level q ∈ (0, 1):
    # - without a plan, pair the q-quantiles of μ and ν directly;
    # - with a plan, map the q-quantile of μ through the user-supplied `plan`
    #   (ν is not consulted in that case).
    integrand = if plan === nothing
        q -> c(quantile(μ, q), quantile(ν, q))
    else
        function (q)
            source_point = quantile(μ, q)
            return c(source_point, plan(source_point))
        end
    end

    # `quadgk` returns a tuple of (integral estimate, error estimate); only the
    # integral estimate is returned to the caller.
    integral_and_error = quadgk(integrand, 0, 1)
    return first(integral_and_error)
end
690+
691+
"""
692+
ot_plan(c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution)
693+
694+
Compute the optimal transport plan for the Monge-Kantorovich problem with univariate
695+
distributions `μ` and `ν` as source and target marginals and cost function `c` of
696+
the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.
697+
698+
In this setting, the optimal transport plan is the Monge map
699+
```math
700+
T = F_\\nu^{-1} \\circ F_\\mu
701+
```
702+
where ``F_\\mu`` is the cumulative distribution function of `μ` and ``F_\\nu^{-1}`` is the
703+
quantile function of `ν`.
704+
705+
See also: [`ot_cost`](@ref), [`emd`](@ref)
706+
"""
707+
function ot_plan(c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution)
    # The returned closure is a map on points (a Monge map), not a coupling:
    # it sends x to the quantile of ν at probability level F_μ(x).
    # The cost `c` is part of the signature but is not read in this body.
    function monge_map(x)
        level = cdf(μ, x)
        return quantile(ν, level)
    end
    return monge_map
end
712+
652713
end

test/runtests.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Distances
44
using PythonOT: PythonOT
55
using Tulip
66
using MathOptInterface
7+
using Distributions
78
using SparseArrays
89

910
using LinearAlgebra
@@ -64,6 +65,51 @@ Random.seed!(100)
6465
end
6566
end
6667

68+
@testset "1D Optimal Transport for Convex Cost" begin
    @testset "continuous distributions" begin
        # two normal distributions (has analytical solution)
        μ = Normal(randn(), rand())
        ν = Normal(randn(), rand())

        # compute OT plan: the returned map must agree with quantile(ν, cdf(μ, ⋅))
        γ = ot_plan(sqeuclidean, μ, ν)
        x = randn()
        @test γ(x) ≈ quantile(ν, cdf(μ, x))

        # compute OT cost and compare with the closed-form expression for two normals
        c = ot_cost(sqeuclidean, μ, ν)
        @test c ≈ (mean(μ) - mean(ν))^2 + (std(μ) - std(ν))^2

        # do not use ν to ensure that the provided plan is used
        @test ot_cost(sqeuclidean, μ, Normal(randn(), rand()); plan=γ) ≈ c
    end

    @testset "semidiscrete case" begin
        μ = Normal(randn(), rand())
        # random probability vector for the discrete target
        νprobs = rand(30)
        νprobs ./= sum(νprobs)
        ν = Categorical(νprobs)

        # compute OT plan
        γ = ot_plan(euclidean, μ, ν)
        x = randn()
        @test γ(x) ≈ quantile(ν, cdf(μ, x))

        # compute OT cost, without and with provided plan
        # do not use ν in the second case to ensure that the provided plan is used
        c = ot_cost(euclidean, μ, ν)
        @test ot_cost(euclidean, μ, Categorical(reverse(νprobs)); plan=γ) ≈ c

        # check that OT cost is consistent with OT cost of a discretization
        # (sampling-based, hence the loose relative tolerance below)
        m = 500
        xs = rand(μ, m)
        μdiscrete = fill(1 / m, m)
        C = pairwise(Euclidean(), xs', (1:length(νprobs))'; dims=2)
        c2 = emd2(μdiscrete, νprobs, C, Tulip.Optimizer())
        @test c2 ≈ c rtol = 1e-1
    end
end
112+
67113
@testset "entropically regularized transport" begin
68114
M = 250
69115
N = 200

0 commit comments

Comments
 (0)