Skip to content

Commit 161a804

Browse files
committed
remove additional file
1 parent 99efa4e commit 161a804

File tree

2 files changed

+134
-134
lines changed

2 files changed

+134
-134
lines changed

src/OptimalTransport.jl

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ export ot_reg_plan
1818

1919
const MOI = MathOptInterface
2020

21-
include("ot_reg.jl")
22-
2321
function __init__()
2422
@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" begin
2523
export POT
@@ -569,4 +567,138 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
569567
end
570568
end
571569

570+
571+
"""
572+
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
573+
574+
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
575+
using the semismooth Newton algorithm [Lorenz 2016].
576+
577+
This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
578+
579+
Parameters:\n
580+
θ: starting Armijo parameter.\n
581+
tol: tolerance of marginal error.\n
582+
maxiter: maximum interation number.\n
583+
κ: control parameter of Armijo.\n
584+
δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
585+
586+
Tips:
587+
If the algorithm does not converge, try some different values of θ.
588+
589+
Reference:
590+
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
591+
"""
592+
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
593+
if !(sum(mu) sum(nu))
594+
throw(ArgumentError("Error: mu and nu must lie in the simplex"))
595+
end
596+
597+
N = length(mu)
598+
M = length(nu)
599+
600+
# initialize dual potentials as uniforms
601+
a = ones(M) ./ M
602+
b = ones(N) ./ N
603+
γ = spzeros(M, N)
604+
605+
da = spzeros(M)
606+
db = spzeros(N)
607+
608+
converged = false
609+
610+
function DualObjective(a, b)
611+
A = a .* ones(N)' + ones(M) .* b' - C'
612+
613+
return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
614+
end
615+
616+
# computes minimizing directions, update γ
617+
function search_dir!(a, b, da, db)
618+
P = a * ones(N)' .+ ones(M) * b' .- C'
619+
620+
σ = 1.0 * sparse(P .>= 0)
621+
γ = sparse(max.(P, 0) ./ ϵ)
622+
623+
G = vcat(
624+
hcat(spdiagm(0 => σ * ones(N)), σ),
625+
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
626+
)
627+
628+
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
629+
630+
x = cg(G + δ * I, -ϵ * h)
631+
632+
da = x[1:M]
633+
return db = x[(M + 1):end]
634+
end
635+
636+
function search_dir(a, b)
637+
P = a * ones(N)' .+ ones(M) * b' .- C'
638+
639+
σ = 1.0 * sparse(P .>= 0)
640+
γ = sparse(max.(P, 0) ./ ϵ)
641+
642+
G = vcat(
643+
hcat(spdiagm(0 => σ * ones(N)), σ),
644+
hcat(sparse'), spdiagm(0 => sparse') * ones(M))),
645+
)
646+
647+
h = vcat* ones(N) - nu, sparse') * ones(M) - mu)
648+
649+
x = cg(G + δ * I, -ϵ * h)
650+
651+
return x[1:M], x[(M + 1):end]
652+
end
653+
654+
# computes optimal maginitude in the minimizing directions
655+
function search_t(a, b, da, db, θ)
656+
d = ϵ * dot(γ, (da .* ones(N)' .+ ones(M) .* db')) - ϵ * (dot(da, nu) + dot(db, mu))
657+
658+
ϕ₀ = DualObjective(a, b)
659+
t = 1
660+
661+
while DualObjective(a + t * da, b + t * db) >= ϕ₀ + t * θ * d
662+
t *= κ
663+
664+
if t < 1e-15
665+
# @warn "@ i = $i, t = $t , armijo did not converge"
666+
break
667+
end
668+
end
669+
return t
670+
end
671+
672+
for i in 1:maxiter
673+
674+
# search_dir!(a, b, da, db)
675+
da, db = search_dir(a, b)
676+
677+
t = search_t(a, b, da, db, θ)
678+
679+
a += t * da
680+
b += t * db
681+
682+
err1 = norm* ones(N) - nu, Inf)
683+
err2 = norm(sparse') * ones(M) - mu, Inf)
684+
685+
if err1 <= tol && err2 <= tol
686+
converged = true
687+
@warn "Converged @ i = $i with marginal errors: \n err1 = $err1, err2 = $err2 \n"
688+
break
689+
elseif i == maxiter
690+
@warn "Not Converged with errors:\n err1 = $err1, err2 = $err2 \n"
691+
end
692+
693+
@debug " t = $t"
694+
@debug "marginal @ i = $i: err1 = $err1, err2 = $err2 "
695+
end
696+
697+
if !converged
698+
@warn "SemiSmooth Newton algorithm did not converge"
699+
end
700+
701+
return sparse')
702+
end
703+
572704
end

src/ot_reg.jl

Lines changed: 0 additions & 132 deletions
This file was deleted.

0 commit comments

Comments
 (0)