@@ -13,7 +13,7 @@ export sinkhorn, sinkhorn2
1313export emd, emd2
1414export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1515export sinkhorn_unbalanced, sinkhorn_unbalanced2
16- export quadreg
16+ export ot_reg_plan
1717
1818const MOI = MathOptInterface
1919
@@ -506,137 +506,12 @@ function sinkhorn_barycenter(
506506 return u_all[1 , :] .* (K_all[1 ] * v_all[1 , :])
507507end
508508
509- """
510- quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)
511-
512- Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and quadratic regularization parameter `ϵ`,
513- using the semismooth Newton algorithm [Lorenz 2016].
514-
515- This implementation makes use of IterativeSolvers.jl and SparseArrays.jl.
516-
517- Parameters:\n
518- θ: starting Armijo parameter.\n
519- tol: tolerance of marginal error.\n
520- maxiter: maximum interation number.\n
521- κ: control parameter of Armijo.\n
522- δ: small constant for the numerical stability of conjugate gradient iterative solver.\n
523-
524- Tips:
525- If the algorithm does not converge, try some different values of θ.
526-
527- Reference:
528- Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. arXiv preprint arXiv:1903.01112v4.
529- """
530- function quadreg (mu, nu, C, ϵ; θ= 0.1 , tol= 1e-5 , maxiter= 50 , κ= 0.5 , δ= 1e-5 )
531- if ! (sum (mu) ≈ sum (nu))
532- throw (ArgumentError (" Error: mu and nu must lie in the simplex" ))
533- end
534-
535- N = length (mu)
536- M = length (nu)
537-
538- # initialize dual potentials as uniforms
539- a = ones (M) ./ M
540- b = ones (N) ./ N
541- γ = spzeros (M, N)
542-
543- da = spzeros (M)
544- db = spzeros (N)
545-
546- converged = false
547-
548- function DualObjective (a, b)
549- A = a .* ones (N)' + ones (M) .* b' - C'
550-
551- return 0.5 * norm (A[A .> 0 ], 2 )^ 2 - ϵ * (dot (nu, a) + dot (mu, b))
552- end
553-
554- # computes minimizing directions, update γ
555- function search_dir! (a, b, da, db)
556- P = a * ones (N)' .+ ones (M) * b' .- C'
557-
558- σ = 1.0 * sparse (P .>= 0 )
559- γ = sparse (max .(P, 0 ) ./ ϵ)
560-
561- G = vcat (
562- hcat (spdiagm (0 => σ * ones (N)), σ),
563- hcat (sparse (σ' ), spdiagm (0 => sparse (σ' ) * ones (M))),
564- )
565-
566- h = vcat (γ * ones (N) - nu, sparse (γ' ) * ones (M) - mu)
567-
568- x = cg (G + δ * I, - ϵ * h)
569-
570- da = x[1 : M]
571- return db = x[(M + 1 ): end ]
572- end
573-
574- function search_dir (a, b)
575- P = a * ones (N)' .+ ones (M) * b' .- C'
576-
577- σ = 1.0 * sparse (P .>= 0 )
578- γ = sparse (max .(P, 0 ) ./ ϵ)
579-
580- G = vcat (
581- hcat (spdiagm (0 => σ * ones (N)), σ),
582- hcat (sparse (σ' ), spdiagm (0 => sparse (σ' ) * ones (M))),
583- )
584-
585- h = vcat (γ * ones (N) - nu, sparse (γ' ) * ones (M) - mu)
586-
587- x = cg (G + δ * I, - ϵ * h)
588-
589- return x[1 : M], x[(M + 1 ): end ]
590- end
591-
592- # computes optimal maginitude in the minimizing directions
593- function search_t (a, b, da, db, θ)
594- d = ϵ * dot (γ, (da .* ones (N)' .+ ones (M) .* db' )) - ϵ * (dot (da, nu) + dot (db, mu))
595-
596- ϕ₀ = DualObjective (a, b)
597- t = 1
598-
599- while DualObjective (a + t * da, b + t * db) >= ϕ₀ + t * θ * d
600- t *= κ
601-
602- if t < 1e-15
603- # @warn "@ i = $i, t = $t , armijo did not converge"
604- break
605- end
606- end
607- return t
608- end
609-
610- for i in 1 : maxiter
611-
612- # search_dir!(a, b, da, db)
613- da, db = search_dir (a, b)
614-
615- t = search_t (a, b, da, db, θ)
616-
617- a += t * da
618- b += t * db
619-
620- err1 = norm (γ * ones (N) - nu, Inf )
621- err2 = norm (sparse (γ' ) * ones (M) - mu, Inf )
622-
623- if err1 <= tol && err2 <= tol
624- converged = true
625- @warn " Converged @ i = $i with marginal errors: \n err1 = $err1 , err2 = $err2 \n "
626- break
627- elseif i == maxiter
628- @warn " Not Converged with errors:\n err1 = $err1 , err2 = $err2 \n "
629- end
630-
631- @debug " t = $t "
632- @debug " marginal @ i = $i : err1 = $err1 , err2 = $err2 "
633- end
634-
635- if ! converged
636- @warn " SemiSmooth Newton algorithm did not converge"
509+ function ot_reg_plan (mu, nu, C, eps; reg_func = " L2" , method = " lorenz" , kwargs... )
510+ if (reg_func == " L2" ) && (method == " lorenz" )
511+ return quadreg (mu, nu, C, eps; kwargs... )
512+ else
513+ @warn " Unimplemented"
637514 end
638-
639- return sparse (γ' )
640515end
641516
642517end
0 commit comments