@@ -14,10 +14,12 @@ export sinkhorn, sinkhorn2
1414export emd, emd2
1515export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1616export sinkhorn_unbalanced, sinkhorn_unbalanced2
17- export quadreg
17+ export ot_reg_plan
1818
1919const MOI = MathOptInterface
2020
21+ include (" ot_reg.jl" )
22+
2123function __init__ ()
2224 @require PyCall = " 438e738f-606a-5dbb-bf0a-cddfbfd45ab0" begin
2325 export POT
@@ -539,137 +541,12 @@ function sinkhorn_barycenter(
539541 return u_all[1 , :] .* (K_all[1 ] * v_all[1 , :])
540542end
541543
542- """
543- quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
544-
545- Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
546- using the semismooth Newton algorithm [Lorenz 2016].
547-
548- This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
549-
550- Parameters:\n
551- θ: starting Armijo parameter.\n
552- tol: tolerance of marginal error.\n
553- maxiter: maximum interation number.\n
554- κ: control parameter of Armijo.\n
555- δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
556-
557- Tips:
558- If the algorithm does not converge, try some different values of θ.
559-
560- Reference:
561- Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
562- """
563- function quadreg (mu, nu, C, ϵ; θ= 0.1 , tol= 1e-5 , maxiter= 50 , κ= 0.5 , δ= 1e-5 )
564- if ! (sum (mu) ≈ sum (nu))
565- throw (ArgumentError (" Error: mu and nu must lie in the simplex" ))
566- end
567-
568- N = length (mu)
569- M = length (nu)
570-
571- # initialize dual potentials as uniforms
572- a = ones (M) ./ M
573- b = ones (N) ./ N
574- γ = spzeros (M, N)
575-
576- da = spzeros (M)
577- db = spzeros (N)
578-
579- converged = false
580-
581- function DualObjective (a, b)
582- A = a .* ones (N)' + ones (M) .* b' - C'
583-
584- return 0.5 * norm (A[A .> 0 ], 2 )^ 2 - ϵ * (dot (nu, a) + dot (mu, b))
585- end
586-
587- # computes minimizing directions, update γ
588- function search_dir! (a, b, da, db)
589- P = a * ones (N)' .+ ones (M) * b' .- C'
590-
591- σ = 1.0 * sparse (P .>= 0 )
592- γ = sparse (max .(P, 0 ) ./ ϵ)
593-
594- G = vcat (
595- hcat (spdiagm (0 => σ * ones (N)), σ),
596- hcat (sparse (σ' ), spdiagm (0 => sparse (σ' ) * ones (M))),
597- )
598-
599- h = vcat (γ * ones (N) - nu, sparse (γ' ) * ones (M) - mu)
600-
601- x = cg (G + δ * I, - ϵ * h)
602-
603- da = x[1 : M]
604- return db = x[(M + 1 ): end ]
605- end
606-
607- function search_dir (a, b)
608- P = a * ones (N)' .+ ones (M) * b' .- C'
609-
610- σ = 1.0 * sparse (P .>= 0 )
611- γ = sparse (max .(P, 0 ) ./ ϵ)
612-
613- G = vcat (
614- hcat (spdiagm (0 => σ * ones (N)), σ),
615- hcat (sparse (σ' ), spdiagm (0 => sparse (σ' ) * ones (M))),
616- )
617-
618- h = vcat (γ * ones (N) - nu, sparse (γ' ) * ones (M) - mu)
619-
620- x = cg (G + δ * I, - ϵ * h)
621-
622- return x[1 : M], x[(M + 1 ): end ]
623- end
624-
625- # computes optimal maginitude in the minimizing directions
626- function search_t (a, b, da, db, θ)
627- d = ϵ * dot (γ, (da .* ones (N)' .+ ones (M) .* db' )) - ϵ * (dot (da, nu) + dot (db, mu))
628-
629- ϕ₀ = DualObjective (a, b)
630- t = 1
631-
632- while DualObjective (a + t * da, b + t * db) >= ϕ₀ + t * θ * d
633- t *= κ
634-
635- if t < 1e-15
636- # @warn "@ i = $i, t = $t , armijo did not converge"
637- break
638- end
639- end
640- return t
641- end
642-
643- for i in 1 : maxiter
644-
645- # search_dir!(a, b, da, db)
646- da, db = search_dir (a, b)
647-
648- t = search_t (a, b, da, db, θ)
649-
650- a += t * da
651- b += t * db
652-
653- err1 = norm (γ * ones (N) - nu, Inf )
654- err2 = norm (sparse (γ' ) * ones (M) - mu, Inf )
655-
656- if err1 <= tol && err2 <= tol
657- converged = true
658- @warn " Converged @ i = $i with marginal errors: \n err1 = $err1 , err2 = $err2 \n "
659- break
660- elseif i == maxiter
661- @warn " Not Converged with errors:\n err1 = $err1 , err2 = $err2 \n "
662- end
663-
664- @debug " t = $t "
665- @debug " marginal @ i = $i : err1 = $err1 , err2 = $err2 "
666- end
667-
668- if ! converged
669- @warn " SemiSmooth Newton algorithm did not converge"
544+ function ot_reg_plan (mu, nu, C, eps; reg_func = " L2" , method = " lorenz" , kwargs... )
545+ if (reg_func == " L2" ) && (method == " lorenz" )
546+ return quadreg (mu, nu, C, eps; kwargs... )
547+ else
548+ @warn " Unimplemented"
670549 end
671-
672- return sparse (γ' )
673550end
674551
675552end
0 commit comments