Skip to content

Commit 5ef3170

Browse files
committed
added ot_reg
1 parent 6acc912 commit 5ef3170

File tree

4 files changed

+142
-133
lines changed

4 files changed

+142
-133
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/OptimalTransport.jl

Lines changed: 8 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ export sinkhorn, sinkhorn2
1414
export emd, emd2
1515
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1616
export sinkhorn_unbalanced, sinkhorn_unbalanced2
17-
export quadreg
17+
export ot_reg_plan
1818

1919
const MOI = MathOptInterface
2020

21+
include("ot_reg.jl")
22+
2123
function __init__()
2224
@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" begin
2325
export POT
@@ -539,137 +541,12 @@ function sinkhorn_barycenter(
539541
return u_all[1, :] .* (K_all[1] * v_all[1, :])
540542
end
541543

542-
"""
543-
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
544-
545-
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
546-
using the semismooth Newton algorithm [Lorenz 2016].
547-
548-
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
549-
550-
Parameters:\n
551-
θ: starting Armijo parameter.\n
552-
tol: tolerance of marginal error.\n
553-
maxiter: maximum interation number.\n
554-
κ: control parameter of Armijo.\n
555-
δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
556-
557-
Tips:
558-
If the algorithm does not converge, try some different values of θ.
559-
560-
Reference:
561-
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
562-
"""
563-
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
564-
if !(sum(mu) sum(nu))
565-
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
566-
end
567-
568-
N = length(mu)
569-
M = length(nu)
570-
571-
# initialize dual potentials as uniforms
572-
a = ones(M) ./ M
573-
b = ones(N) ./ N
574-
γ = spzeros(M, N)
575-
576-
da = spzeros(M)
577-
db = spzeros(N)
578-
579-
converged = false
580-
581-
function DualObjective(a, b)
582-
A = a .* ones(N)' + ones(M) .* b' - C'
583-
584-
return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
585-
end
586-
587-
# computes minimizing directions, update γ
588-
function search_dir!(a, b, da, db)
589-
P = a * ones(N)' .+ ones(M) * b' .- C'
590-
591-
σ = 1.0 * sparse(P .>= 0)
592-
γ = sparse(max.(P, 0) ./ ϵ)
593-
594-
G = vcat(
595-
hcat(spdiagm(0 => σ * ones(N)), σ),
596-
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
597-
)
598-
599-
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
600-
601-
x = cg(G + δ * I, -ϵ * h)
602-
603-
da = x[1:M]
604-
return db = x[(M + 1):end]
605-
end
606-
607-
function search_dir(a, b)
608-
P = a * ones(N)' .+ ones(M) * b' .- C'
609-
610-
σ = 1.0 * sparse(P .>= 0)
611-
γ = sparse(max.(P, 0) ./ ϵ)
612-
613-
G = vcat(
614-
hcat(spdiagm(0 => σ * ones(N)), σ),
615-
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
616-
)
617-
618-
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
619-
620-
x = cg(G + δ * I, -ϵ * h)
621-
622-
return x[1:M], x[(M + 1):end]
623-
end
624-
625-
# computes optimal maginitude in the minimizing directions
626-
function search_t(a, b, da, db, θ)
627-
d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu))
628-
629-
ϕ₀ = DualObjective(a, b)
630-
t = 1
631-
632-
while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d
633-
t *= κ
634-
635-
if t < 1e-15
636-
# @warn "@ i = $i, t = $t , armijo did not converge"
637-
break
638-
end
639-
end
640-
return t
641-
end
642-
643-
for i in 1:maxiter
644-
645-
# search_dir!(a, b, da, db)
646-
da, db = search_dir(a, b)
647-
648-
t = search_t(a, b, da, db, θ)
649-
650-
a += t * da
651-
b += t * db
652-
653-
err1 = norm* ones(N) - nu, Inf)
654-
err2 = norm(sparse') * ones(M) - mu, Inf)
655-
656-
if err1 <= tol && err2 <= tol
657-
converged = true
658-
@warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n"
659-
break
660-
elseif i == maxiter
661-
@warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n"
662-
end
663-
664-
@debug " t = $t"
665-
@debug "marginal @ i = $i: err1 = $err1, err2 = $err2 "
666-
end
667-
668-
if !converged
669-
@warn "SemiSmooth Newton algorithm did not converge"
544+
function ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...)
545+
if (reg_func == "L2") && (method == "lorenz")
546+
return quadreg(mu, nu, C, eps; kwargs...)
547+
else
548+
@warn "Unimplemented"
670549
end
671-
672-
return sparse')
673550
end
674551

675552
end

src/ot_reg.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
3+
4+
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
5+
using the semismooth Newton algorithm [Lorenz 2016].
6+
7+
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
8+
9+
Parameters:\n
10+
θ: starting Armijo parameter.\n
11+
tol: tolerance of marginal error.\n
12+
maxiter: maximum interation number.\n
13+
κ: control parameter of Armijo.\n
14+
δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
15+
16+
Tips:
17+
If the algorithm does not converge, try some different values of θ.
18+
19+
Reference:
20+
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
21+
"""
22+
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
23+
if !(sum(mu) sum(nu))
24+
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
25+
end
26+
27+
N = length(mu)
28+
M = length(nu)
29+
30+
# initialize dual potentials as uniforms
31+
a = ones(M) ./ M
32+
b = ones(N) ./ N
33+
γ = spzeros(M, N)
34+
35+
da = spzeros(M)
36+
db = spzeros(N)
37+
38+
converged = false
39+
40+
function DualObjective(a, b)
41+
A = a .* ones(N)' + ones(M) .* b' - C'
42+
43+
return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
44+
end
45+
46+
# computes minimizing directions, update γ
47+
function search_dir!(a, b, da, db)
48+
P = a * ones(N)' .+ ones(M) * b' .- C'
49+
50+
σ = 1.0 * sparse(P .>= 0)
51+
γ = sparse(max.(P, 0) ./ ϵ)
52+
53+
G = vcat(
54+
hcat(spdiagm(0 => σ * ones(N)), σ),
55+
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
56+
)
57+
58+
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
59+
60+
x = cg(G + δ * I, -ϵ * h)
61+
62+
da = x[1:M]
63+
return db = x[(M + 1):end]
64+
end
65+
66+
function search_dir(a, b)
67+
P = a * ones(N)' .+ ones(M) * b' .- C'
68+
69+
σ = 1.0 * sparse(P .>= 0)
70+
γ = sparse(max.(P, 0) ./ ϵ)
71+
72+
G = vcat(
73+
hcat(spdiagm(0 => σ * ones(N)), σ),
74+
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
75+
)
76+
77+
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
78+
79+
x = cg(G + δ * I, -ϵ * h)
80+
81+
return x[1:M], x[(M + 1):end]
82+
end
83+
84+
# computes optimal maginitude in the minimizing directions
85+
function search_t(a, b, da, db, θ)
86+
d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu))
87+
88+
ϕ₀ = DualObjective(a, b)
89+
t = 1
90+
91+
while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d
92+
t *= κ
93+
94+
if t < 1e-15
95+
# @warn "@ i = $i, t = $t , armijo did not converge"
96+
break
97+
end
98+
end
99+
return t
100+
end
101+
102+
for i in 1:maxiter
103+
104+
# search_dir!(a, b, da, db)
105+
da, db = search_dir(a, b)
106+
107+
t = search_t(a, b, da, db, θ)
108+
109+
a += t * da
110+
b += t * db
111+
112+
err1 = norm* ones(N) - nu, Inf)
113+
err2 = norm(sparse') * ones(M) - mu, Inf)
114+
115+
if err1 <= tol && err2 <= tol
116+
converged = true
117+
@warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n"
118+
break
119+
elseif i == maxiter
120+
@warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n"
121+
end
122+
123+
@debug " t = $t"
124+
@debug "marginal @ i = $i: err1 = $err1, err2 = $err2 "
125+
end
126+
127+
if !converged
128+
@warn "SemiSmooth Newton algorithm did not converge"
129+
end
130+
131+
return sparse')
132+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ end
197197

198198
# compute optimal transport map (Julia implementation + POT)
199199
eps = 0.25
200-
γ = quadreg(μ, ν, C, eps)
200+
γ = ot_reg_plan(μ, ν, C, eps)
201201
γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps; max_iter=5000))
202202
# need to use a larger tolerance here because of a quirk with the POT solver
203203
@test norm- γ_pot, Inf) < 1e-4

0 commit comments

Comments
 (0)