Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
julia = "1"
ArrayInterface = "3.1.17"
StatsAPI = "1"
julia = "1"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down
1 change: 1 addition & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Distances

using ArrayInterface: device, AbstractDevice, GPU
using LinearAlgebra
using Statistics
import StatsAPI: pairwise, pairwise!
Expand Down
21 changes: 21 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,27 @@ _eltype(::Type{Union{Missing, T}}) where {T} = Union{Missing, T}
__eltype(::Base.HasEltype, a) = _eltype(eltype(a))
__eltype(::Base.EltypeUnknown, a) = _eltype(typeof(first(a)))


# Supertype of the strategy traits used to select an `_evaluate` implementation.
# `infer_evaluate_strategy` returns an instance of one of the subtypes below.
abstract type AbstractEvaluateStrategy end
# Evaluate with a broadcast followed by a reduction; chosen when at least one input
# reports a `GPU` device.
struct Vectorization <: AbstractEvaluateStrategy end
# Evaluate with the scalar-indexing map-reduce methods; chosen for every device
# combination that does not involve a `GPU`.
struct ScalarMapReduce <: AbstractEvaluateStrategy end

# Choose the evaluation strategy for distance `d` applied to `a` and `b`.
# The device of each input is queried first; the decision itself is made purely by
# dispatch in `_infer_evaluate_strategy`.
infer_evaluate_strategy(d::PreMetric, a, b) =
    _infer_evaluate_strategy(d, device(a), device(b))
# Fallback for any pair of devices: use the scalar map-reduce implementation.
@inline _infer_evaluate_strategy(::PreMetric, ::AbstractDevice, ::AbstractDevice) = ScalarMapReduce()
# When one (or both) of the inputs is a scalar type, its device is `nothing`.
@inline _infer_evaluate_strategy(::PreMetric, ::AbstractDevice, ::Nothing) = ScalarMapReduce()
@inline _infer_evaluate_strategy(::PreMetric, ::Nothing, ::AbstractDevice) = ScalarMapReduce()
@inline _infer_evaluate_strategy(::PreMetric, ::Nothing, ::Nothing) = ScalarMapReduce()
# Scalar indexing is far slower if any of the given arrays is a GPU array, so
# those combinations use the vectorized implementation instead.
@inline _infer_evaluate_strategy(::PreMetric, ::AbstractDevice, ::GPU) = Vectorization()
@inline _infer_evaluate_strategy(::PreMetric, ::GPU, ::AbstractDevice) = Vectorization()
@inline _infer_evaluate_strategy(::PreMetric, ::GPU, ::GPU) = Vectorization()


# Generic column-wise evaluation

"""
Expand Down
42 changes: 29 additions & 13 deletions src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,26 @@ result_type(dist::UnionMetrics, ::Type{Ta}, ::Type{Tb}, ::Nothing) where {Ta,Tb}
result_type(dist::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb} =
typeof(_evaluate(dist, oneunit(Ta), oneunit(Tb), oneunit(_eltype(p))))

Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b)
_evaluate(d, a, b, parameters(d))
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p=parameters(d))
_evaluate(infer_evaluate_strategy(d, a, b), d, a, b, p)
end
for M in (metrics..., weightedmetrics...)
@eval @inline (dist::$M)(a, b) = _evaluate(dist, a, b)
Comment on lines +226 to +227
Copy link
Contributor Author

@johnnychen94 johnnychen94 Jun 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This for loop is moved from L327-L329.

end

# breaks the implementation into eval_start, eval_op, eval_reduce and eval_end
# Vectorized evaluation for metrics without parameters (`p === nothing`).
#
# The element-wise terms are produced by broadcasting `eval_op` over `a` and `b`, then
# folded with `eval_reduce` starting from `eval_start(d, a, b)`, and finally passed
# through `eval_end`. No scalar indexing of `a` or `b` happens in this method.
#
# Throws `DimensionMismatch` when `a` and `b` differ in length. Broadcasting alone
# would silently expand a length-1 input, so the check mirrors the one performed by
# the `ScalarMapReduce` methods and keeps both strategies consistent.
function _evaluate(::Vectorization, d::UnionMetrics, a, b, ::Nothing)
    @boundscheck if length(a) != length(b)
        throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    map_op(x, y) = eval_op(d, x, y)
    reduce_op(x, y) = eval_reduce(d, x, y)
    return eval_end(d, reduce(reduce_op, map_op.(a, b); init=eval_start(d, a, b)))
end
# Vectorized evaluation for metrics that carry per-element parameters `p`.
#
# `eval_op` is broadcast over `a`, `b` and `p` together; the resulting terms are folded
# with `eval_reduce` starting from `eval_start(d, a, b)` and passed through `eval_end`.
#
# Throws `DimensionMismatch` when `a`, `b` and `p` do not all have the same length.
# Broadcasting alone would silently expand a length-1 argument, so the lengths are
# validated up front, matching the checks of the `ScalarMapReduce` methods.
function _evaluate(::Vectorization, d::UnionMetrics, a, b, p)
    @boundscheck if length(a) != length(b)
        throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    @boundscheck if length(a) != length(p)
        throw(DimensionMismatch("data collections have length $(length(a)) but parameters have length $(length(p))."))
    end
    map_op(x, y, w) = eval_op(d, x, y, w)
    reduce_op(x, y) = eval_reduce(d, x, y)
    return eval_end(d, reduce(reduce_op, map_op.(a, b, p); init=eval_start(d, a, b)))
end

Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, ::Nothing)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a, b, ::Nothing)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -239,7 +252,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, ::Nothing)
end
return eval_end(d, s)
end
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray, ::Nothing)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a::AbstractArray, b::AbstractArray, ::Nothing)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -263,7 +276,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b
end
end

Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a, b, p)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first collection has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -279,7 +292,7 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a, b, p)
end
return eval_end(d, s)
end
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray, p::AbstractArray)
Base.@propagate_inbounds function _evaluate(::ScalarMapReduce, d::UnionMetrics, a::AbstractArray, b::AbstractArray, p::AbstractArray)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand Down Expand Up @@ -308,8 +321,8 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b
end
end

_evaluate(dist::UnionMetrics, a::Number, b::Number, ::Nothing) = eval_end(dist, eval_op(dist, a, b))
function _evaluate(dist::UnionMetrics, a::Number, b::Number, p)
# Two plain numbers and no parameters: a single `eval_op` term, finalized by `eval_end`.
function _evaluate(::ScalarMapReduce, dist::UnionMetrics, a::Number, b::Number, ::Nothing)
    term = eval_op(dist, a, b)
    return eval_end(dist, term)
end
# Two plain numbers with parameters `p`: `p` must hold exactly one element, which is
# handed to `eval_op` together with the two scalars. Throws `DimensionMismatch` otherwise.
function _evaluate(::ScalarMapReduce, dist::UnionMetrics, a::Number, b::Number, p)
    if length(p) != 1
        throw(DimensionMismatch("inputs are scalars but parameters have length $(length(p))."))
    end
    term = eval_op(dist, a, b, first(p))
    return eval_end(dist, term)
end
Expand All @@ -324,10 +337,6 @@ _eval_start(d::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb} =
# Default reduction for any `UnionMetrics`: accumulate the per-element terms by addition.
function eval_reduce(::UnionMetrics, s1, s2)
    return s1 + s2
end

# Default finalizer for any `UnionMetrics`: hand back the accumulated value unchanged.
function eval_end(::UnionMetrics, s)
    return s
end

for M in (metrics..., weightedmetrics...)
@eval @inline (dist::$M)(a, b) = _evaluate(dist, a, b, parameters(dist))
end

# Euclidean
# Per-element term of the Euclidean distance: the squared magnitude of the difference.
@inline function eval_op(::Euclidean, ai, bi)
    return abs2(ai - bi)
end

# The accumulated sum of squares is turned into a distance by taking its square root.
function eval_end(::Euclidean, s)
    return sqrt(s)
end
Expand Down Expand Up @@ -373,7 +382,14 @@ totalvariation(a, b) = TotalVariation()(a, b)
# Per-element term of the Chebyshev distance: the absolute difference.
@inline function eval_op(::Chebyshev, ai, bi)
    return abs(ai - bi)
end

# Terms are combined by keeping the larger of the running value and the new term.
@inline function eval_reduce(::Chebyshev, s1, s2)
    return max(s1, s2)
end
# if only NaN, will output NaN
Base.@propagate_inbounds eval_start(::Chebyshev, a, b) = abs(first(a) - first(b))
# Initial accumulator for the Chebyshev reduction.
#
# Every term `abs(ai - bi)` is non-negative, so `zero(T)` is a lower bound of the
# distance and a valid seed for the `max` reduction.
#
# No separate NaN scan is performed. The reduction operator for `Chebyshev` is `max`,
# which returns NaN as soon as either argument is NaN, so a NaN term reaches the
# result on its own. Dropping the scan also avoids two extra passes over the inputs
# (each an additional reduction on GPU arrays) and the `TypeError` raised by
# `any(isnan, a) || ...` when `any` returns `missing` for inputs containing `missing`.
#
# Neither `a` nor `b` is indexed here, so the function is safe for arrays on which
# scalar indexing is slow.
Base.@propagate_inbounds function eval_start(d::Chebyshev, a, b)
    return zero(result_type(d, a, b))
end
Comment on lines +386 to +393
Copy link
Contributor Author

@johnnychen94 johnnychen94 Jun 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rewritten to support CuArray: the previous implementation scalar-indexed the first element (`first(a)`, `first(b)`), which is slow for CuArray.

# Convenience wrapper: evaluate a freshly constructed `Chebyshev` metric on `a` and `b`.
function chebyshev(a, b)
    metric = Chebyshev()
    return metric(a, b)
end

# Minkowski
Expand Down
9 changes: 8 additions & 1 deletion test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ end

function test_metricity(dist, x, y, z)
@testset "Test metricity of $(typeof(dist))" begin
@test dist(x, y) == evaluate(dist, x, y)
d = dist(x, y)
@test d == evaluate(dist, x, y)
# The strategy trait is dispatched on the metric, so the check must be made on the
# metric object `dist`. The computed distance `d` is a number and is never a
# `UnionMetrics`, which would leave the comparison below unexecuted.
if dist isa Distances.UnionMetrics
    # currently only UnionMetrics supports this strategy trait
    p = Distances.parameters(dist)
    d_vec = Distances._evaluate(Distances.Vectorization(), dist, x, y, p)
    d_scalar = Distances._evaluate(Distances.ScalarMapReduce(), dist, x, y, p)
    @test d_vec ≈ d_scalar
end

dxy = dist(x, y)
dxz = dist(x, z)
Expand Down