Skip to content

Commit ada5933

Browse files
Add optimization for discrete uniform distributions of equal size (#17)
* Add optimization for discrete uniform distributions of equal size
* Update test/exact.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Add optimization for `ot_plan`
* Fix test
* Add compat entry for FillArrays
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent dcda744 commit ada5933

File tree

5 files changed

+135
-97
lines changed

5 files changed

+135
-97
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "ExactOptimalTransport"
22
uuid = "24df6009-d856-477c-ac5c-91f668376b31"
33
authors = ["JuliaOptimalTransport"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1112
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
@@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1617
[compat]
1718
Distances = "0.9.0, 0.10"
1819
Distributions = "0.24, 0.25"
20+
FillArrays = "0.12"
1921
MathOptInterface = "0.9"
2022
PDMats = "0.10, 0.11"
2123
QuadGK = "2"

src/ExactOptimalTransport.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ExactOptimalTransport
33
using Distances
44
using MathOptInterface
55
using Distributions
6+
using FillArrays
67
using PDMats
78
using QuadGK
89
using StatsBase: StatsBase

src/exact.jl

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -263,30 +263,38 @@ a sparse matrix.
263263
See also: [`ot_cost`](@ref), [`emd`](@ref)
264264
"""
265265
function ot_plan(_, μ::DiscreteNonParametric, ν::DiscreteNonParametric)
266-
# unpack the probabilities of the two distributions
266+
# Unpack the probabilities of the two distributions
267+
# Note: support of `DiscreteNonParametric` is sorted
267268
μprobs = probs(μ)
268269
νprobs = probs(ν)
269-
270-
# create the iterator
271-
# note: support of `DiscreteNonParametric` is sorted
272-
iter = Discrete1DOTIterator(μprobs, νprobs)
273-
274-
# create arrays for the indices of the two histograms and the optimal flow between the
275-
# corresponding points
276-
n = length(iter)
277-
I = Vector{Int}(undef, n)
278-
J = Vector{Int}(undef, n)
279-
W = Vector{Base.promote_eltype(μprobs, νprobs)}(undef, n)
280-
281-
# compute the sparse optimal transport plan
282-
@inbounds for (idx, (i, j, w)) in enumerate(iter)
283-
I[idx] = i
284-
J[idx] = j
285-
W[idx] = w
270+
T = Base.promote_eltype(μprobs, νprobs)
271+
272+
return if μprobs isa FillArrays.AbstractFill &&
273+
νprobs isa FillArrays.AbstractFill &&
274+
length(μprobs) == length(νprobs)
275+
# Special case: discrete uniform distributions of the same "size"
276+
k = length(μprobs)
277+
sparse(1:k, 1:k, T(first(μprobs)), k, k)
278+
else
279+
# Generic case
280+
# Create the iterator
281+
iter = Discrete1DOTIterator(μprobs, νprobs)
282+
283+
# create arrays for the indices of the two histograms and the optimal flow between the
284+
# corresponding points
285+
n = length(iter)
286+
I = Vector{Int}(undef, n)
287+
J = Vector{Int}(undef, n)
288+
W = Vector{T}(undef, n)
289+
290+
# compute the sparse optimal transport plan
291+
@inbounds for (idx, (i, j, w)) in enumerate(iter)
292+
I[idx] = i
293+
J[idx] = j
294+
W[idx] = w
295+
end
296+
sparse(I, J, W, length(μprobs), length(νprobs))
286297
end
287-
γ = sparse(I, J, W, length(μprobs), length(νprobs))
288-
289-
return γ
290298
end
291299

292300
"""
@@ -305,45 +313,50 @@ A pre-computed optimal transport `plan` may be provided.
305313
See also: [`ot_plan`](@ref), [`emd2`](@ref)
306314
"""
307315
function ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; plan=nothing)
308-
return _ot_cost(c, μ, ν, plan)
309-
end
310-
311-
# compute cost from scratch if no plan is provided
312-
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, ::Nothing)
313-
# unpack the probabilities of the two distributions
316+
# Extract support and probabilities of discrete distributions
317+
# Note: support of `DiscreteNonParametric` is sorted
318+
μsupport = support(μ)
319+
νsupport = support(ν)
314320
μprobs = probs(μ)
315321
νprobs = probs(ν)
316322

323+
return if μprobs isa FillArrays.AbstractFill &&
324+
νprobs isa FillArrays.AbstractFill &&
325+
length(μprobs) == length(νprobs)
326+
# Special case: discrete uniform distributions of the same "size"
327+
# In this case we always just compute `sum(c.(μsupport .- νsupport))` and scale it
328+
# We use pairwise summation and avoid allocations
329+
# (https://github.com/JuliaLang/julia/pull/31020)
330+
T = Base.promote_eltype(μprobs, νprobs)
331+
T(first(μprobs)) *
332+
sum(Broadcast.instantiate(Broadcast.broadcasted(c, μsupport, νsupport)))
333+
else
334+
# Generic case
335+
_ot_cost(c, μsupport, μprobs, νsupport, νprobs, plan)
336+
end
337+
end
338+
339+
# compute cost from scratch if no plan is provided
340+
function _ot_cost(c, μsupport, μprobs, νsupport, νprobs, ::Nothing)
317341
# create the iterator
318-
# note: support of `DiscreteNonParametric` is sorted
319342
iter = Discrete1DOTIterator(μprobs, νprobs)
320343

321344
# compute the cost
322-
μsupport = support(μ)
323-
νsupport = support(ν)
324-
cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter)
325-
326-
return cost
345+
return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter)
327346
end
328347

329348
# if a sparse plan is provided, we just iterate through the non-zero entries
330-
function _ot_cost(
331-
c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan::SparseMatrixCSC
332-
)
349+
function _ot_cost(c, μsupport, _, νsupport, _, plan::SparseMatrixCSC)
333350
# extract non-zero flows
334351
I, J, W = findnz(plan)
335352

336353
# compute the cost
337-
μsupport = support(μ)
338-
νsupport = support(ν)
339-
cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W))
340-
341-
return cost
354+
return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W))
342355
end
343356

344357
# fallback: compute cost matrix (probably often faster to compute cost from scratch)
345-
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan)
346-
return dot(plan, StatsBase.pairwise(c, support(μ), support(ν)))
358+
function _ot_cost(c, μsupport, _, νsupport, _, plan)
359+
return dot(plan, StatsBase.pairwise(c, μsupport, νsupport))
347360
end
348361

349362
################

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end
1212
"""
1313
discretemeasure(
1414
support::AbstractVector,
15-
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
15+
probs::AbstractVector{<:Real}=FillArrays.Fill(inv(length(support)), length(support)),
1616
)
1717
1818
Construct a finite discrete probability measure with `support` and corresponding
@@ -42,13 +42,13 @@ using KernelFunctions
4242
"""
4343
function discretemeasure(
4444
support::AbstractVector{<:Real},
45-
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
45+
probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)),
4646
)
4747
return DiscreteNonParametric(support, probs)
4848
end
4949
function discretemeasure(
5050
support::AbstractVector,
51-
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
51+
probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)),
5252
)
5353
return FiniteDiscreteMeasure{typeof(support),typeof(probs)}(support, probs)
5454
end

test/exact.jl

Lines changed: 72 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ExactOptimalTransport
22

33
using Distances
4+
using FillArrays
45
using PythonOT: PythonOT
56
using Tulip
67
using MathOptInterface
@@ -110,56 +111,77 @@ Random.seed!(100)
110111
end
111112

112113
@testset "discrete case" begin
113-
# random source and target marginal
114-
m = 30
115-
μprobs = normalize!(rand(m), 1)
116-
μsupport = randn(m)
117-
μ = DiscreteNonParametric(μsupport, μprobs)
118-
119-
n = 50
120-
νprobs = normalize!(rand(n), 1)
121-
νsupport = randn(n)
122-
ν = DiscreteNonParametric(νsupport, νprobs)
123-
124-
# compute OT plan
125-
γ = @inferred(ot_plan(euclidean, μ, ν))
126-
@test γ isa SparseMatrixCSC
127-
@test size(γ) == (m, n)
128-
@test vec(sum(γ; dims=2)) μ.p
129-
@test vec(sum(γ; dims=1)) ν.p
130-
131-
# consistency checks
132-
I, J, W = findnz(γ)
133-
@test all(w > zero(w) for w in W)
134-
@test sum(W) ≈ 1
135-
@test sort(unique(I)) == 1:m
136-
@test sort(unique(J)) == 1:n
137-
@test sort(I .+ J) == 2:(m + n)
138-
139-
# compute OT cost
140-
c = @inferred(ot_cost(euclidean, μ, ν))
141-
142-
# compare with computation with explicit cost matrix
143-
# DiscreteNonParametric sorts the support automatically, here we have to sort
144-
# manually
145-
C = pairwise(Euclidean(), μsupport', νsupport'; dims=2)
146-
c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer())
147-
@test c2 ≈ c rtol = 1e-5
148-
149-
# compare with POT
150-
# disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds
151-
# error
152-
# @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
153-
# @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
154-
155-
# do not use the probabilities of μ and ν to ensure that the provided plan is
156-
# used
157-
μ2 = DiscreteNonParametric(μsupport, reverse(μprobs))
158-
ν2 = DiscreteNonParametric(νsupport, reverse(νprobs))
159-
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ))
160-
@test c2 ≈ c
161-
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
162-
@test c2 ≈ c
114+
# different random sources and target marginals:
115+
# non-uniform + different size, uniform + different size, uniform + equal size
116+
for (μ, ν) in (
117+
(
118+
DiscreteNonParametric(randn(30), normalize!(rand(30), 1)),
119+
DiscreteNonParametric(randn(50), normalize!(rand(50), 1)),
120+
),
121+
(
122+
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
123+
DiscreteNonParametric(randn(50), Fill(1 / 50, 50)),
124+
),
125+
(
126+
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
127+
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
128+
),
129+
)
130+
# extract support, probabilities, and "size"
131+
μsupport = support(μ)
132+
μprobs = probs(μ)
133+
m = length(μprobs)
134+
135+
νsupport = support(ν)
136+
νprobs = probs(ν)
137+
n = length(νprobs)
138+
139+
# compute OT plan
140+
γ = @inferred(ot_plan(euclidean, μ, ν))
141+
@test γ isa SparseMatrixCSC
142+
@test size(γ) == (m, n)
143+
@test vec(sum(γ; dims=2)) ≈ μ.p
144+
@test vec(sum(γ; dims=1)) ≈ ν.p
145+
146+
# consistency checks
147+
I, J, W = findnz(γ)
148+
@test all(w > zero(w) for w in W)
149+
@test sum(W) ≈ 1
150+
@test sort(unique(I)) == 1:m
151+
@test sort(unique(J)) == 1:n
152+
@test sort(I .+ J) == if μprobs isa Fill && νprobs isa Fill && m == n
153+
# Optimized version for special case (discrete uniform + equal size)
154+
2:2:(m + n)
155+
else
156+
# Generic case (not optimized)
157+
2:(m + n)
158+
end
159+
160+
# compute OT cost
161+
c = @inferred(ot_cost(euclidean, μ, ν))
162+
163+
# compare with computation with explicit cost matrix
164+
# DiscreteNonParametric sorts the support automatically, here we have to sort
165+
# manually
166+
C = pairwise(Euclidean(), μsupport', νsupport'; dims=2)
167+
c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer())
168+
@test c2 ≈ c rtol = 1e-5
169+
170+
# compare with POT
171+
# disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds
172+
# error
173+
# @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
174+
# @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
175+
176+
# do not use the probabilities of μ and ν to ensure that the provided plan is
177+
# used
178+
μ2 = DiscreteNonParametric(μsupport, reverse(μprobs))
179+
ν2 = DiscreteNonParametric(νsupport, reverse(νprobs))
180+
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ))
181+
@test c2 ≈ c
182+
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
183+
@test c2 ≈ c
184+
end
163185
end
164186
end
165187

0 commit comments

Comments
 (0)