Skip to content

Commit 010f66d

Browse files
committed
added ot_reg_cost
1 parent 161a804 commit 010f66d

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/OptimalTransport.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export sinkhorn, sinkhorn2
1414
export emd, emd2
1515
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1616
export sinkhorn_unbalanced, sinkhorn_unbalanced2
17-
export ot_reg_plan
17+
export ot_reg_plan, ot_reg_cost
1818

1919
const MOI = MathOptInterface
2020

@@ -567,6 +567,25 @@ function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
567567
end
568568
end
569569

570+
"""
571+
ot_reg_cost(mu, nu, C, eps; reg_func = "L2", method = "lorenz", kwargs...)
572+
573+
Compute the optimal transport cost between `mu` and `nu` for optimal transport with a
574+
general choice of regulariser `math Ω(γ)`.
575+
576+
See also: [`ot_reg_plan`](@ref)
577+
578+
"""
579+
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
580+
γ = if (reg_func == "L2") && (method == "lorenz")
581+
quadreg(mu, nu, C, eps; kwargs...)
582+
else
583+
@warn "Unimplemented"
584+
nothing
585+
end
586+
return dot(γ, C)
587+
end
588+
570589

571590
"""
572591
quadreg(mu, nu, C, ϵ; θ = 0.1, tol = 1e-5,maxiter = 50,κ = 0.5,δ = 1e-5)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ end
201201
γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps; max_iter=5000))
202202
# need to use a larger tolerance here because of a quirk with the POT solver
203203
@test norm- γ_pot, Inf) < 1e-4
204+
c = ot_reg_cost(μ, ν, C, eps; reg_func="L2", method="lorenz")
205+
c_pot = dot(γ_pot, C)
206+
@test c c_pot atol = 1e-4
204207
end
205208
end
206209

0 commit comments

Comments
 (0)