@@ -263,30 +263,38 @@ a sparse matrix.
263263See also: [`ot_cost`](@ref), [`emd`](@ref)
264264"""
265265function ot_plan (_, μ:: DiscreteNonParametric , ν:: DiscreteNonParametric )
266- # unpack the probabilities of the two distributions
266+ # Unpack the probabilities of the two distributions
267+ # Note: support of `DiscreteNonParametric` is sorted
267268 μprobs = probs (μ)
268269 νprobs = probs (ν)
269-
270- # create the iterator
271- # note: support of `DiscreteNonParametric` is sorted
272- iter = Discrete1DOTIterator (μprobs, νprobs)
273-
274- # create arrays for the indices of the two histograms and the optimal flow between the
275- # corresponding points
276- n = length (iter)
277- I = Vector {Int} (undef, n)
278- J = Vector {Int} (undef, n)
279- W = Vector {Base.promote_eltype(μprobs, νprobs)} (undef, n)
280-
281- # compute the sparse optimal transport plan
282- @inbounds for (idx, (i, j, w)) in enumerate (iter)
283- I[idx] = i
284- J[idx] = j
285- W[idx] = w
270+ T = Base. promote_eltype (μprobs, νprobs)
271+
272+ return if μprobs isa FillArrays. AbstractFill &&
273+ νprobs isa FillArrays. AbstractFill &&
274+ length (μprobs) == length (νprobs)
275+ # Special case: discrete uniform distributions of the same "size"
276+ k = length (μprobs)
277+ sparse (1 : k, 1 : k, T (first (μprobs)), k, k)
278+ else
279+ # Generic case
280+ # Create the iterator
281+ iter = Discrete1DOTIterator (μprobs, νprobs)
282+
283+ # create arrays for the indices of the two histograms and the optimal flow between the
284+ # corresponding points
285+ n = length (iter)
286+ I = Vector {Int} (undef, n)
287+ J = Vector {Int} (undef, n)
288+ W = Vector {T} (undef, n)
289+
290+ # compute the sparse optimal transport plan
291+ @inbounds for (idx, (i, j, w)) in enumerate (iter)
292+ I[idx] = i
293+ J[idx] = j
294+ W[idx] = w
295+ end
296+ sparse (I, J, W, length (μprobs), length (νprobs))
286297 end
287- γ = sparse (I, J, W, length (μprobs), length (νprobs))
288-
289- return γ
290298end
291299
292300"""
@@ -305,45 +313,50 @@ A pre-computed optimal transport `plan` may be provided.
305313See also: [`ot_plan`](@ref), [`emd2`](@ref)
306314"""
307315function ot_cost (c, μ:: DiscreteNonParametric , ν:: DiscreteNonParametric ; plan= nothing )
308- return _ot_cost (c, μ, ν, plan)
309- end
310-
311- # compute cost from scratch if no plan is provided
312- function _ot_cost (c, μ:: DiscreteNonParametric , ν:: DiscreteNonParametric , :: Nothing )
313- # unpack the probabilities of the two distributions
316+ # Extract support and probabilities of discrete distributions
317+ # Note: support of `DiscreteNonParametric` is sorted
318+ μsupport = support (μ)
319+ νsupport = support (ν)
314320 μprobs = probs (μ)
315321 νprobs = probs (ν)
316322
323+ return if μprobs isa FillArrays. AbstractFill &&
324+ νprobs isa FillArrays. AbstractFill &&
325+ length (μprobs) == length (νprobs)
326+ # Special case: discrete uniform distributions of the same "size"
327+ # In this case we always just compute `sum(c.(μsupport .- νsupport))` and scale it
328+ # We use pairwise summation and avoid allocations
329+ # (https://github.com/JuliaLang/julia/pull/31020)
330+ T = Base. promote_eltype (μprobs, νprobs)
331+ T (first (μprobs)) *
332+ sum (Broadcast. instantiate (Broadcast. broadcasted (c, μsupport, νsupport)))
333+ else
334+ # Generic case
335+ _ot_cost (c, μsupport, μprobs, νsupport, νprobs, plan)
336+ end
337+ end
338+
339+ # compute cost from scratch if no plan is provided
340+ function _ot_cost (c, μsupport, μprobs, νsupport, νprobs, :: Nothing )
317341 # create the iterator
318- # note: support of `DiscreteNonParametric` is sorted
319342 iter = Discrete1DOTIterator (μprobs, νprobs)
320343
321344 # compute the cost
322- μsupport = support (μ)
323- νsupport = support (ν)
324- cost = sum (w * c (μsupport[i], νsupport[j]) for (i, j, w) in iter)
325-
326- return cost
345+ return sum (w * c (μsupport[i], νsupport[j]) for (i, j, w) in iter)
327346end
328347
329348# if a sparse plan is provided, we just iterate through the non-zero entries
330- function _ot_cost (
331- c, μ:: DiscreteNonParametric , ν:: DiscreteNonParametric , plan:: SparseMatrixCSC
332- )
349+ function _ot_cost (c, μsupport, _, νsupport, _, plan:: SparseMatrixCSC )
333350 # extract non-zero flows
334351 I, J, W = findnz (plan)
335352
336353 # compute the cost
337- μsupport = support (μ)
338- νsupport = support (ν)
339- cost = sum (w * c (μsupport[i], νsupport[j]) for (i, j, w) in zip (I, J, W))
340-
341- return cost
354+ return sum (w * c (μsupport[i], νsupport[j]) for (i, j, w) in zip (I, J, W))
342355end
343356
344357# fallback: compute cost matrix (probably often faster to compute cost from scratch)
345- function _ot_cost (c, μ :: DiscreteNonParametric , ν :: DiscreteNonParametric , plan)
346- return dot (plan, StatsBase. pairwise (c, support (μ), support (ν) ))
358+ function _ot_cost (c, μsupport, _, νsupport, _ , plan)
359+ return dot (plan, StatsBase. pairwise (c, μsupport, νsupport ))
347360end
348361
349362# ###############
0 commit comments