Skip to content

Commit 0907838

Browse files
authored
Fix emd with integer-valued cost matrices (#72)
1 parent 8e709bc commit 0907838

File tree

3 files changed

+42
-26
lines changed

3 files changed

+42
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/OptimalTransport.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ function emd(μ, ν, C, model::MOI.ModelLike)
4343
xmat = reshape(x, nμ, nν)
4444

4545
# define objective function
46-
T = eltype(C)
46+
T = float(eltype(C))
4747
zero_T = zero(T)
4848
MOI.set(
4949
model,
5050
MOI.ObjectiveFunction{MOI.ScalarAffineFunction{T}}(),
51-
MOI.ScalarAffineFunction(MOI.ScalarAffineTerm.(vec(C), x), zero_T),
51+
MOI.ScalarAffineFunction(MOI.ScalarAffineTerm.(float.(vec(C)), x), zero_T),
5252
)
5353
MOI.set(model, MOI.ObjectiveSense(), MOI.MIN_SENSE)
5454

test/runtests.jl

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,45 @@ Random.seed!(100)
2323
μ ./= sum(μ)
2424
ν ./= sum(ν)
2525

26-
# create random cost matrix
27-
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
28-
29-
# compute optimal transport map and cost with POT
30-
pot_P = POT.emd(μ, ν, C)
31-
pot_cost = POT.emd2(μ, ν, C)
32-
33-
# compute optimal transport map and cost with Tulip
34-
lp = Tulip.Optimizer()
35-
P = emd(μ, ν, C, lp)
36-
@test size(C) == size(P)
37-
@test MOI.get(lp, MOI.TerminationStatus()) == MOI.OPTIMAL
38-
@test maximum(abs, P .- pot_P) < 1e-2
39-
40-
lp = Tulip.Optimizer()
41-
cost = emd2(μ, ν, C, lp)
42-
@test dot(C, P) cost atol = 1e-5
43-
@test MOI.get(lp, MOI.TerminationStatus()) == MOI.OPTIMAL
44-
@test cost pot_cost atol = 1e-5
45-
46-
# ensure that provided map is used
47-
cost2 = emd2(similar(μ), similar(ν), C, lp; plan=P)
48-
@test cost2 cost
26+
@testset "example" begin
27+
# create random cost matrix
28+
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
29+
30+
# compute optimal transport map and cost with POT
31+
pot_P = POT.emd(μ, ν, C)
32+
pot_cost = POT.emd2(μ, ν, C)
33+
34+
# compute optimal transport map and cost with Tulip
35+
lp = Tulip.Optimizer()
36+
P = emd(μ, ν, C, lp)
37+
@test size(C) == size(P)
38+
@test MOI.get(lp, MOI.TerminationStatus()) == MOI.OPTIMAL
39+
@test maximum(abs, P .- pot_P) < 1e-2
40+
41+
lp = Tulip.Optimizer()
42+
cost = emd2(μ, ν, C, lp)
43+
@test dot(C, P) cost atol = 1e-5
44+
@test MOI.get(lp, MOI.TerminationStatus()) == MOI.OPTIMAL
45+
@test cost pot_cost atol = 1e-5
46+
end
47+
48+
@testset "pre-computed plan" begin
49+
# create random cost matrix
50+
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
51+
52+
# compute optimal transport map
53+
P = emd(μ, ν, C, Tulip.Optimizer())
54+
55+
# do not use μ and ν to ensure that provided map is used
56+
cost = emd2(similar(μ), similar(ν), C, Tulip.Optimizer(); plan=P)
57+
@test cost emd2(μ, ν, C, Tulip.Optimizer())
58+
end
59+
60+
# https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/71
61+
@testset "cost matrix with integers" begin
62+
C = pairwise(SqEuclidean(), rand(1:10, 1, M), rand(1:10, 1, N); dims=2)
63+
emd2(μ, ν, C, Tulip.Optimizer())
64+
end
4965
end
5066

5167
@testset "entropically regularized transport" begin

0 commit comments

Comments
 (0)