Skip to content

Commit 90a2f42

Browse files
committed
use juliaformatter
1 parent ca49b26 commit 90a2f42

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

src/OptimalTransport.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ function emd2(μ, ν, C, optimizer; map=nothing)
115115
# check dimensions
116116
size(C) == (length(μ), length(ν)) ||
117117
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
118-
size(map) == size(C) ||
119-
error("optimal transport map `map` and cost matrix `C` must be of the same size")
118+
size(map) == size(C) || error(
119+
"optimal transport map `map` and cost matrix `C` must be of the same size",
120+
)
120121
map
121122
end
122123
return dot(γ, C)
@@ -225,8 +226,9 @@ function sinkhorn2(μ, ν, C, ε; map=nothing, kwargs...)
225226
# check dimensions
226227
size(C) == (length(μ), length(ν)) ||
227228
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
228-
size(map) == size(C) ||
229-
error("optimal transport map `map` and cost matrix `C` must be of the same size")
229+
size(map) == size(C) || error(
230+
"optimal transport map `map` and cost matrix `C` must be of the same size",
231+
)
230232
map
231233
end
232234
return dot(γ, C)
@@ -328,8 +330,9 @@ function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; map=nothing, kwargs...)
328330
# check dimensions
329331
size(C) == (length(μ), length(ν)) ||
330332
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
331-
size(map) == size(C) ||
332-
error("optimal transport map `map` and cost matrix `C` must be of the same size")
333+
size(map) == size(C) || error(
334+
"optimal transport map `map` and cost matrix `C` must be of the same size",
335+
)
333336
map
334337
end
335338
return dot(γ, C)

src/pot.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,18 @@ function sinkhorn_unbalanced2(
155155
end
156156

157157
function smooth_ot_dual(
158-
mu, nu, C, eps, reg_type = "l2", method = "L-BFGS-B", tol = 1e-9, max_iter = 500, verbose = false
158+
mu, nu, C, eps, reg_type="l2", method="L-BFGS-B", tol=1e-9, max_iter=500, verbose=false
159159
)
160-
return pot.smooth.smooth_ot_dual(nu, mu, PyReverseDims(C), eps, reg_type = reg_type, method = method, stopThr = tol, numItermax = max_iter)'
160+
return pot.smooth.smooth_ot_dual(
161+
nu,
162+
mu,
163+
PyReverseDims(C),
164+
eps;
165+
reg_type=reg_type,
166+
method=method,
167+
stopThr=tol,
168+
numItermax=max_iter,
169+
)'
161170
end
162171

163172
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,6 @@ end
189189
γ = quadreg(μ, ν, C, eps)
190190
γ_pot = sparse(POT.smooth_ot_dual(μ, ν, C, eps))
191191
# need to use a larger tolerance here because of a quirk with the POT solver
192-
@test norm- γ_pot, Inf) < 0.5e-4
192+
@test norm- γ_pot, Inf) < 0.5e-4
193193
end
194194
end

0 commit comments

Comments
 (0)