Skip to content

Commit 72e3e6d

Browse files
authored
Merge pull request #56 from JuliaOptimalTransport/test_barycenter
Tests for sinkhorn_barycenter
2 parents 369d1ce + 4f1851e commit 72e3e6d

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/pot.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,26 @@ function smooth_ot_dual(
169169
)'
170170
end
171171

172+
function barycenter(
173+
mu_all,
174+
C,
175+
eps;
176+
weights=nothing,
177+
method="sinkhorn",
178+
max_iter=10000,
179+
tol=0.0001,
180+
verbose=false,
181+
)
182+
return pot.barycenter(
183+
mu_all',
184+
C,
185+
eps;
186+
weights=weights,
187+
method=method,
188+
numItermax=max_iter,
189+
stopThr=tol,
190+
verbose=verbose,
191+
)
192+
end
193+
172194
end

test/runtests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,23 @@ end
192192
@test norm- γ_pot, Inf) < 1e-4
193193
end
194194
end
195+
196+
@testset "sinkhorn barycenter" begin
197+
@testset "example" begin
198+
# set up support
199+
support = range(-1, 1; length=250)
200+
μ1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
201+
μ1 ./= sum(μ1)
202+
μ2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
203+
μ2 ./= sum(μ2)
204+
μ_all = hcat(μ1, μ2)'
205+
# create cost matrix
206+
C = pairwise(SqEuclidean(), support')
207+
# compute Sinkhorn barycenter (Julia implementation + POT)
208+
eps = 0.01
209+
μ_interp = sinkhorn_barycenter(μ_all, [C, C], eps, [0.5, 0.5])
210+
μ_interp_pot = POT.barycenter(μ_all, C, eps; weights=[0.5, 0.5])
211+
# need to use a larger tolerance here because of a quirk with the POT solver
212+
@test norm(μ_interp - μ_interp_pot, Inf) < 1e-9
213+
end
214+
end

0 commit comments

Comments
 (0)