Skip to content

Commit 42a31ac

Browse files
committed
added ot_reg
1 parent 18bf8c8 commit 42a31ac

File tree

3 files changed

+139
-132
lines changed

3 files changed

+139
-132
lines changed

src/OptimalTransport.jl

Lines changed: 6 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2
1313
export emd, emd2
1414
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1515
export sinkhorn_unbalanced, sinkhorn_unbalanced2
16-
export quadreg
16+
export ot_reg_plan
1717

1818
const MOI = MathOptInterface
1919

@@ -506,137 +506,12 @@ function sinkhorn_barycenter(
506506
return u_all[1, :] .* (K_all[1] * v_all[1, :])
507507
end
508508

509-
"""
510-
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
511-
512-
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
513-
using the semismooth Newton algorithm [Lorenz 2016].
514-
515-
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
516-
517-
Parameters:\n
518-
θ: starting Armijo parameter.\n
519-
tol: tolerance of marginal error.\n
520-
maxiter: maximum interation number.\n
521-
κ: control parameter of Armijo.\n
522-
δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
523-
524-
Tips:
525-
If the algorithm does not converge, try some different values of θ.
526-
527-
Reference:
528-
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
529-
"""
530-
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
531-
if !(sum(mu) ≈ sum(nu))
532-
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
533-
end
534-
535-
N = length(mu)
536-
M = length(nu)
537-
538-
# initialize dual potentials as uniforms
539-
a = ones(M) ./ M
540-
b = ones(N) ./ N
541-
γ = spzeros(M, N)
542-
543-
da = spzeros(M)
544-
db = spzeros(N)
545-
546-
converged = false
547-
548-
function DualObjective(a, b)
549-
A = a .* ones(N)' + ones(M) .* b' - C'
550-
551-
return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
552-
end
553-
554-
# computes minimizing directions, update γ
555-
function search_dir!(a, b, da, db)
556-
P = a * ones(N)' .+ ones(M) * b' .- C'
557-
558-
σ = 1.0 * sparse(P .>= 0)
559-
γ = sparse(max.(P, 0) ./ ϵ)
560-
561-
G = vcat(
562-
hcat(spdiagm(0 => σ * ones(N)), σ),
563-
hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))),
564-
)
565-
566-
h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu)
567-
568-
x = cg(G + δ * I, -ϵ * h)
569-
570-
da = x[1:M]
571-
return db = x[(M + 1):end]
572-
end
573-
574-
function search_dir(a, b)
575-
P = a * ones(N)' .+ ones(M) * b' .- C'
576-
577-
σ = 1.0 * sparse(P .>= 0)
578-
γ = sparse(max.(P, 0) ./ ϵ)
579-
580-
G = vcat(
581-
hcat(spdiagm(0 => σ * ones(N)), σ),
582-
hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))),
583-
)
584-
585-
h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu)
586-
587-
x = cg(G + δ * I, -ϵ * h)
588-
589-
return x[1:M], x[(M + 1):end]
590-
end
591-
592-
# computes optimal maginitude in the minimizing directions
593-
function search_t(a, b, da, db, θ)
594-
d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu))
595-
596-
ϕ₀ = DualObjective(a, b)
597-
t = 1
598-
599-
while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d
600-
t *= κ
601-
602-
if t < 1e-15
603-
# @warn "@ i = $i, t = $t , armijo did not converge"
604-
break
605-
end
606-
end
607-
return t
608-
end
609-
610-
for i in 1:maxiter
611-
612-
# search_dir!(a, b, da, db)
613-
da, db = search_dir(a, b)
614-
615-
t = search_t(a, b, da, db, θ)
616-
617-
a += t * da
618-
b += t * db
619-
620-
err1 = norm(γ * ones(N) - nu, Inf)
621-
err2 = norm(sparse(γ') * ones(M) - mu, Inf)
622-
623-
if err1 <= tol && err2 <= tol
624-
converged = true
625-
@warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n"
626-
break
627-
elseif i == maxiter
628-
@warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n"
629-
end
630-
631-
@debug " t = $t"
632-
@debug "marginal @ i = $i: err1 = $err1, err2 = $err2 "
633-
end
634-
635-
if !converged
636-
@warn "SemiSmooth Newton algorithm did not converge"
509+
function ot_reg_plan(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...)
510+
if (reg_func == "L2") && (method == "lorenz")
511+
return quadreg(mu, nu, C, eps; kwargs...)
512+
else
513+
@warn "Unimplemented"
637514
end
638-
639-
return sparse(γ')
640515
end
641516

642517
end

src/ot_reg.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
3+
4+
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
5+
using the semismooth Newton algorithm [Lorenz 2016].
6+
7+
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
8+
9+
Parameters:\n
10+
θ: starting Armijo parameter.\n
11+
tol: tolerance of marginal error.\n
12+
maxiter: maximum interation number.\n
13+
κ: control parameter of Armijo.\n
14+
δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
15+
16+
Tips:
17+
If the algorithm does not converge, try some different values of θ.
18+
19+
Reference:
20+
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
21+
"""
22+
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
23+
if !(sum(mu) ≈ sum(nu))
24+
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
25+
end
26+
27+
N = length(mu)
28+
M = length(nu)
29+
30+
# initialize dual potentials as uniforms
31+
a = ones(M) ./ M
32+
b = ones(N) ./ N
33+
γ = spzeros(M, N)
34+
35+
da = spzeros(M)
36+
db = spzeros(N)
37+
38+
converged = false
39+
40+
function DualObjective(a, b)
41+
A = a .* ones(N)' + ones(M) .* b' - C'
42+
43+
return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
44+
end
45+
46+
# computes minimizing directions, update γ
47+
function search_dir!(a, b, da, db)
48+
P = a * ones(N)' .+ ones(M) * b' .- C'
49+
50+
σ = 1.0 * sparse(P .>= 0)
51+
γ = sparse(max.(P, 0) ./ ϵ)
52+
53+
G = vcat(
54+
hcat(spdiagm(0 => σ * ones(N)), σ),
55+
hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))),
56+
)
57+
58+
h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu)
59+
60+
x = cg(G + δ * I, -ϵ * h)
61+
62+
da = x[1:M]
63+
return db = x[(M + 1):end]
64+
end
65+
66+
function search_dir(a, b)
67+
P = a * ones(N)' .+ ones(M) * b' .- C'
68+
69+
σ = 1.0 * sparse(P .>= 0)
70+
γ = sparse(max.(P, 0) ./ ϵ)
71+
72+
G = vcat(
73+
hcat(spdiagm(0 => σ * ones(N)), σ),
74+
hcat(sparse(σ'), spdiagm(0 => sparse(σ') * ones(M))),
75+
)
76+
77+
h = vcat(γ * ones(N) - nu, sparse(γ') * ones(M) - mu)
78+
79+
x = cg(G + δ * I, -ϵ * h)
80+
81+
return x[1:M], x[(M + 1):end]
82+
end
83+
84+
# computes optimal maginitude in the minimizing directions
85+
function search_t(a, b, da, db, θ)
86+
d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu))
87+
88+
ϕ₀ = DualObjective(a, b)
89+
t = 1
90+
91+
while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d
92+
t *= κ
93+
94+
if t < 1e-15
95+
# @warn "@ i = $i, t = $t , armijo did not converge"
96+
break
97+
end
98+
end
99+
return t
100+
end
101+
102+
for i in 1:maxiter
103+
104+
# search_dir!(a, b, da, db)
105+
da, db = search_dir(a, b)
106+
107+
t = search_t(a, b, da, db, θ)
108+
109+
a += t * da
110+
b += t * db
111+
112+
err1 = norm(γ * ones(N) - nu, Inf)
113+
err2 = norm(sparse(γ') * ones(M) - mu, Inf)
114+
115+
if err1 <= tol && err2 <= tol
116+
converged = true
117+
@warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n"
118+
break
119+
elseif i == maxiter
120+
@warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n"
121+
end
122+
123+
@debug " t = $t"
124+
@debug "marginal @ i = $i: err1 = $err1, err2 = $err2 "
125+
end
126+
127+
if !converged
128+
@warn "SemiSmooth Newton algorithm did not converge"
129+
end
130+
131+
return sparse(γ')
132+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ end
165165

166166
# compute optimal transport map (Julia implementation + POT)
167167
eps = 0.25
168-
γ = quadreg(μ, ν, C, eps)
168+
γ = ot_reg_plan(μ, ν, C, eps)
169169
γ_pot = POT.Smooth.smooth_ot_dual(μ, ν, C, eps; stopThr=1e-9)
170170
# need to use a larger tolerance here because of a quirk with the POT solver
171171
@test norm(γ - γ_pot, Inf) < 1e-4

0 commit comments

Comments
 (0)