Skip to content

Commit ca49b26

Browse files
committed
add test for quadratic OT and POT.smooth_ot_dual
1 parent 3bbc0d5 commit ca49b26

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
11+
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
1112
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
13+
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
1214
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1315

1416
[compat]

src/pot.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,10 @@ function sinkhorn_unbalanced2(
154154
)[1]
155155
end
156156

157+
function smooth_ot_dual(
158+
mu, nu, C, eps, reg_type = "l2", method = "L-BFGS-B", tol = 1e-9, max_iter = 500, verbose = false
159+
)
160+
return pot.smooth.smooth_ot_dual(nu, mu, PyReverseDims(C), eps, reg_type = reg_type, method = method, stopThr = tol, numItermax = max_iter)'
161+
end
162+
157163
end

test/runtests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Distances
55
using PyCall
66
using Tulip
77
using MathOptInterface
8+
using SparseArrays
89

910
using LinearAlgebra
1011
using Random
@@ -171,3 +172,23 @@ end
171172
@test norm- γ_pot, Inf) < 1e-9
172173
end
173174
end
175+
176+
@testset "quadratic optimal transport" begin
177+
M = 250
178+
N = 200
179+
@testset "example" begin
180+
# create two uniform histograms
181+
μ = fill(1 / M, M)
182+
ν = fill(1 / N, N)
183+
184+
# create random cost matrix
185+
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
186+
187+
# compute optimal transport map (Julia implementation + POT)
188+
eps = 0.5
189+
γ = quadreg(μ, ν, C, eps)
190+
γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps))
191+
# need to use a larger tolerance here because of a quirk with the POT solver
192+
@test norm- γ_pot, Inf) < 0.5e-4
193+
end
194+
end

0 commit comments

Comments
 (0)