@@ -18,8 +18,6 @@ export ot_reg_plan
1818
1919const MOI = MathOptInterface
2020
21- include (" ot_reg.jl" )
22-
2321function __init__ ()
2422 @require PyCall = " 438e738f-606a-5dbb-bf0a-cddfbfd45ab0" begin
2523 export POT
@@ -569,4 +567,138 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
569567 end
570568end
571569
570+
"""
    quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)

Compute the optimal transport plan between the histograms `mu` and `nu` with cost matrix
`C` and quadratic regularization parameter `ϵ`, using the semismooth Newton algorithm of
Lorenz, Manns and Meyer (2019).

This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.

# Arguments
- `mu`, `nu`: histograms; their sums must be approximately equal.
- `C`: cost matrix of size `(length(mu), length(nu))`.
- `ϵ`: quadratic regularization parameter.

# Keywords
- `θ`: starting Armijo parameter.
- `tol`: tolerance on the marginal errors (infinity norm).
- `maxiter`: maximum number of iterations.
- `κ`: factor by which the step size is shrunk during the Armijo line search.
- `δ`: small constant added to the diagonal of the Newton system for the numerical
  stability of the conjugate gradient iterative solver.

# Returns
A sparse matrix of size `(length(mu), length(nu))`. It is the plan whose marginal errors
were evaluated last, i.e. the plan that passed the tolerance test when the algorithm
converged.

# Throws
- `ArgumentError` if `sum(mu)` and `sum(nu)` are not approximately equal.
- `DimensionMismatch` if `C` does not have size `(length(mu), length(nu))`.

# Tips
If the algorithm does not converge, try some different values of `θ`.

# Reference
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport.
arXiv preprint arXiv:1903.01112v4.
"""
function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
    if !(sum(mu) ≈ sum(nu))
        throw(ArgumentError("mu and nu must have the same total mass"))
    end

    N = length(mu)
    M = length(nu)
    if size(C) != (N, M)
        throw(DimensionMismatch("cost matrix C must have size ($N, $M), got $(size(C))"))
    end

    # loop-invariant vectors of ones used to form the marginals of the plan
    onesN = ones(N)
    onesM = ones(M)

    # initialize dual potentials as uniforms
    a = ones(M) ./ M
    b = ones(N) ./ N
    γ = spzeros(M, N)

    err1 = Inf
    err2 = Inf
    converged = false

    for i in 1:maxiter
        # plan induced by the current dual potentials (stored transposed, M × N)
        P = a .+ b' .- C'
        γ = sparse(max.(P, 0) ./ ϵ)

        # marginal residuals, reused for the stopping test and the Newton right-hand side
        res_nu = γ * onesN - nu
        res_mu = sparse(γ') * onesM - mu
        err1 = norm(res_nu, Inf)
        err2 = norm(res_mu, Inf)
        @debug "marginal @ i = $i: err1 = $err1, err2 = $err2"

        # test before solving: a plan that already satisfies the tolerance needs no step
        if err1 <= tol && err2 <= tol
            converged = true
            @debug "Converged @ i = $i with marginal errors: err1 = $err1, err2 = $err2"
            break
        end

        da, db = _quadreg_direction(P, res_nu, res_mu, ϵ, δ)
        t = _quadreg_stepsize(a, b, da, db, γ, mu, nu, C, ϵ, θ, κ)
        @debug "t = $t"

        a += t * da
        b += t * db
    end

    if !converged
        @warn "SemiSmooth Newton algorithm did not converge" maxiter err1 err2
    end

    return sparse(γ')
end

# Dual objective that the semismooth Newton iteration minimizes over the potentials (a, b).
function _quadreg_dual_objective(a, b, mu, nu, C, ϵ)
    A = a .+ b' .- C'
    return 0.5 * norm(A[A .> 0], 2)^2 - ϵ * (dot(nu, a) + dot(mu, b))
end

# Newton direction (da, db): solves the δ-regularized Newton system with conjugate gradient.
function _quadreg_direction(P, res_nu, res_mu, ϵ, δ)
    M, N = size(P)

    # indicator of the active set of the plan
    σ = 1.0 * sparse(P .>= 0)
    σt = sparse(σ')

    G = vcat(
        hcat(spdiagm(0 => σ * ones(N)), σ),
        hcat(σt, spdiagm(0 => σt * ones(M))),
    )

    x = cg(G + δ * I, -ϵ * vcat(res_nu, res_mu))

    return x[1:M], x[(M + 1):end]
end

# Armijo backtracking: shrinks t by κ until sufficient decrease, giving up below 1e-15.
function _quadreg_stepsize(a, b, da, db, γ, mu, nu, C, ϵ, θ, κ)
    # directional derivative of the dual objective along (da, db)
    d = ϵ * dot(γ, da .+ db') - ϵ * (dot(da, nu) + dot(db, mu))

    ϕ₀ = _quadreg_dual_objective(a, b, mu, nu, C, ϵ)
    t = one(κ)

    while _quadreg_dual_objective(a + t * da, b + t * db, mu, nu, C, ϵ) >= ϕ₀ + t * θ * d
        t *= κ

        # the line search failed to find a decreasing step; stop shrinking
        if t < 1e-15
            break
        end
    end

    return t
end
703+
572704end
0 commit comments