@@ -534,4 +534,138 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
534534 end
535535end
536536
537+
538+ """
539+ quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
540+
541+ Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
542+ using the semismooth Newton algorithm [Lorenz 2016].
543+
544+ This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
545+
546+ Parameters:\n
547+ θ: starting Armijo parameter.\n
548+ tol: tolerance of marginal error.\n
549+ maxiter: maximum interation number.\n
550+ κ: control parameter of Armijo.\n
551+ δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
552+
553+ Tips:
554+ If the algorithm does not converge, try some different values of θ.
555+
556+ Reference:
557+ Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
558+ """
559+ function quadreg (mu, nu, C, ϵ; θ= 0.1 , tol= 1e-5 , maxiter= 50 , κ= 0.5 , δ= 1e-5 )
560+ if ! (sum (mu) ≈ sum (nu))
561+ throw (ArgumentError (" Error: mu and nu must lie in the simplex" ))
562+ end
563+
564+ N = length (mu)
565+ M = length (nu)
566+
567+ # initialize dual potentials as uniforms
568+ a = ones (M) ./ M
569+ b = ones (N) ./ N
570+ γ = spzeros (M, N)
571+
572+ da = spzeros (M)
573+ db = spzeros (N)
574+
575+ converged = false
576+
577+ function DualObjective (a, b)
578+ A = a .* ones (N)' + ones (M) .* b' - C'
579+
580+ return 0.5 * norm (A[A .> 0 ], 2 )^ 2 - ϵ * (dot (nu, a) + dot (mu, b))
581+ end
582+
583+ # computes minimizing directions, update γ
584+ function search_dir! (a, b, da, db)
585+ P = a * ones (N)' .+ ones (M) * b' .- C'
586+
587+ σ = 1.0 * sparse (P .>= 0 )
588+ γ = sparse (max .(P, 0 ) ./ ϵ)
589+
590+ G = vcat (
591+ hcat (spdiagm (0 => σ * ones (N)), σ),
592+ hcat (sparse (σ' ), spdiagm (0 => sparse (σ' ) * ones (M))),
593+ )
594+
595+ h = vcat (γ * ones (N) - nu, sparse (γ' ) * ones (M) - mu)
596+
597+ x = cg (G + δ * I, - ϵ * h)
598+
599+ da = x[1 : M]
600+ return db = x[(M + 1 ): end ]
601+ end
602+
603+ function search_dir (a, b)
604+ P = a * ones (N)' .+ ones (M) * b' .- C'
605+
606+ σ = 1.0 * sparse (P .>= 0 )
607+ γ = sparse (max .(P, 0 ) ./ ϵ)
608+
609+ G = vcat (
610+ hcat (spdiagm (0 => σ * ones (N)), σ),
611+ hcat (sparse (σ' ), spdiagm (0 => sparse (σ' ) * ones (M))),
612+ )
613+
614+ h = vcat (γ * ones (N) - nu, sparse (γ' ) * ones (M) - mu)
615+
616+ x = cg (G + δ * I, - ϵ * h)
617+
618+ return x[1 : M], x[(M + 1 ): end ]
619+ end
620+
621+ # computes optimal maginitude in the minimizing directions
622+ function search_t (a, b, da, db, θ)
623+ d = ϵ * dot (γ, (da .* ones (N)' .+ ones (M) .* db' )) - ϵ * (dot (da, nu) + dot (db, mu))
624+
625+ ϕ₀ = DualObjective (a, b)
626+ t = 1
627+
628+ while DualObjective (a + t * da, b + t * db) >= ϕ₀ + t * θ * d
629+ t *= κ
630+
631+ if t < 1e-15
632+ # @warn "@ i = $i, t = $t , armijo did not converge"
633+ break
634+ end
635+ end
636+ return t
637+ end
638+
639+ for i in 1 : maxiter
640+
641+ # search_dir!(a, b, da, db)
642+ da, db = search_dir (a, b)
643+
644+ t = search_t (a, b, da, db, θ)
645+
646+ a += t * da
647+ b += t * db
648+
649+ err1 = norm (γ * ones (N) - nu, Inf )
650+ err2 = norm (sparse (γ' ) * ones (M) - mu, Inf )
651+
652+ if err1 <= tol && err2 <= tol
653+ converged = true
654+ @warn " Converged @ i = $i with marginal errors: \n err1 = $err1 , err2 = $err2 \n "
655+ break
656+ elseif i == maxiter
657+ @warn " Not Converged with errors:\n err1 = $err1 , err2 = $err2 \n "
658+ end
659+
660+ @debug " t = $t "
661+ @debug " marginal @ i = $i : err1 = $err1 , err2 = $err2 "
662+ end
663+
664+ if ! converged
665+ @warn " SemiSmooth Newton algorithm did not converge"
666+ end
667+
668+ return sparse (γ' )
669+ end
670+
537671end
0 commit comments