Skip to content

Commit 3e8a3c8

Browse files
committed
remove additional file
1 parent 4e758f8 commit 3e8a3c8

File tree

2 files changed

+134
-132
lines changed

2 files changed

+134
-132
lines changed

src/OptimalTransport.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,4 +534,138 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
534534
end
535535
end
536536

537+
538+
"""
539+
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
540+
541+
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
542+
using the semismooth Newton algorithm [Lorenz 2016].
543+
544+
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
545+
546+
Parameters:\n
547+
θ: starting Armijo parameter.\n
548+
tol: tolerance of marginal error.\n
549+
maxiter: maximum interation number.\n
550+
κ: control parameter of Armijo.\n
551+
δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
552+
553+
Tips:
554+
If the algorithm does not converge, try some different values of θ.
555+
556+
Reference:
557+
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
558+
"""
559+
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
560+
if !(sum(mu) sum(nu))
561+
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
562+
end
563+
564+
N = length(mu)
565+
M = length(nu)
566+
567+
# initialize dual potentials as uniforms
568+
a = ones(M) ./ M
569+
b = ones(N) ./ N
570+
γ = spzeros(M, N)
571+
572+
da = spzeros(M)
573+
db = spzeros(N)
574+
575+
converged = false
576+
577+
function DualObjective(a, b)
578+
A = a .* ones(N)' + ones(M) .* b' - C'
579+
580+
return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
581+
end
582+
583+
# computes minimizing directions, update γ
584+
function search_dir!(a, b, da, db)
585+
P = a * ones(N)' .+ ones(M) * b' .- C'
586+
587+
σ = 1.0 * sparse(P .>= 0)
588+
γ = sparse(max.(P, 0) ./ ϵ)
589+
590+
G = vcat(
591+
hcat(spdiagm(0 => σ * ones(N)), σ),
592+
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
593+
)
594+
595+
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
596+
597+
x = cg(G + δ * I, -ϵ * h)
598+
599+
da = x[1:M]
600+
return db = x[(M + 1):end]
601+
end
602+
603+
function search_dir(a, b)
604+
P = a * ones(N)' .+ ones(M) * b' .- C'
605+
606+
σ = 1.0 * sparse(P .>= 0)
607+
γ = sparse(max.(P, 0) ./ ϵ)
608+
609+
G = vcat(
610+
hcat(spdiagm(0 => σ * ones(N)), σ),
611+
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
612+
)
613+
614+
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
615+
616+
x = cg(G + δ * I, -ϵ * h)
617+
618+
return x[1:M], x[(M + 1):end]
619+
end
620+
621+
# computes optimal maginitude in the minimizing directions
622+
function search_t(a, b, da, db, θ)
623+
d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu))
624+
625+
ϕ₀ = DualObjective(a, b)
626+
t = 1
627+
628+
while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d
629+
t *= κ
630+
631+
if t < 1e-15
632+
# @warn "@ i = $i, t = $t , armijo did not converge"
633+
break
634+
end
635+
end
636+
return t
637+
end
638+
639+
for i in 1:maxiter
640+
641+
# search_dir!(a, b, da, db)
642+
da, db = search_dir(a, b)
643+
644+
t = search_t(a, b, da, db, θ)
645+
646+
a += t * da
647+
b += t * db
648+
649+
err1 = norm* ones(N) - nu, Inf)
650+
err2 = norm(sparse') * ones(M) - mu, Inf)
651+
652+
if err1 <= tol && err2 <= tol
653+
converged = true
654+
@warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n"
655+
break
656+
elseif i == maxiter
657+
@warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n"
658+
end
659+
660+
@debug " t = $t"
661+
@debug "marginal @ i = $i: err1 = $err1, err2 = $err2 "
662+
end
663+
664+
if !converged
665+
@warn "SemiSmooth Newton algorithm did not converge"
666+
end
667+
668+
return sparse')
669+
end
670+
537671
end

src/ot_reg.jl

Lines changed: 0 additions & 132 deletions
This file was deleted.

0 commit comments

Comments
 (0)