Skip to content

Commit e03f7d6

Browse files
johnnychen94KristofferC
authored andcommitted
move implementations into type overloading (aka. functor) (#139)
* make evaluate a general API * BhattacharyyaDist and HellingerDist * Bregman * Haversine * SqMahalanobis and Mahalanobis * metrics in metrics.jl CorrDist is excluded from `UnionMetrics` since it's a simple wrap on CosineDist * remove specification of pairwise and colwise on CorrDist The previous specification is needed to pass a centralized input, now we don't need it anymore * metrics in wmetrics.jl * colwise and pairwise * update README.md * update test * rename metric_list to metrics * revert auto-formatted spaces changed parts: * spaces after type annotation, e.g, `b::AbstractMatrix = a` not changed parts: * additional spaces at the end of line * spaces between operations and comma, e.g., `a + b` and `(a, b)` * update format * rollback auto-format on whitespaces * test if there are any ambiguities This PR fixes all the ambiguities as a good start, future PRs may not break this.
1 parent 293457e commit e03f7d6

File tree

10 files changed

+115
-120
lines changed

10 files changed

+115
-120
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@ Each distance corresponds to a *distance type*. You can always compute a certain
5454

5555
```julia
5656
r = evaluate(dist, x, y)
57+
r = dist(x, y)
5758
```
5859

5960
Here, dist is an instance of a distance type. For example, the type for Euclidean distance is ``Euclidean`` (more distance types will be introduced in the next section), then you can compute the Euclidean distance between ``x`` and ``y`` as
6061

6162
```julia
6263
r = evaluate(Euclidean(), x, y)
64+
r = Euclidean()(x, y)
6365
```
6466

6567
Common distances also come with convenient functions for distance evaluation. For example, you may also compute Euclidean distance between two vectors as below

src/bhattacharyya.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,11 @@ bhattacharyya_coeff(a::T, b::T) where {T <: Number} = throw("Bhattacharyya coeff
3737

3838

3939
# Bhattacharyya distance
40-
evaluate(dist::BhattacharyyaDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b))
41-
bhattacharyya(a::AbstractVector, b::AbstractVector) = evaluate(BhattacharyyaDist(), a, b)
42-
evaluate(dist::BhattacharyyaDist, a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars")
43-
bhattacharyya(a::T, b::T) where {T <: Number} = evaluate(BhattacharyyaDist(), a, b)
40+
(::BhattacharyyaDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b))
41+
(::BhattacharyyaDist)(a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars")
42+
bhattacharyya(a, b) = BhattacharyyaDist()(a, b)
4443

4544
# Hellinger distance
46-
evaluate(dist::HellingerDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b))
47-
hellinger(a::AbstractVector, b::AbstractVector) = evaluate(HellingerDist(), a, b)
48-
evaluate(dist::HellingerDist, a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars")
49-
hellinger(a::T, b::T) where {T <: Number} = evaluate(HellingerDist(), a, b)
45+
(::HellingerDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b))
46+
(::HellingerDist)(a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars")
47+
hellinger(a, b) = HellingerDist()(a, b)

src/bregman.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
Bregman(F, ∇) = Bregman(F, ∇, LinearAlgebra.dot)
2323

2424
# Evaluation fuction
25-
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
25+
function (dist::Bregman)(p::AbstractVector, q::AbstractVector)
2626
# Create cache vals.
2727
FP_val = dist.F(p);
2828
FQ_val = dist.F(q);
@@ -45,4 +45,4 @@ function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
4545
end
4646

4747
# Convenience function.
48-
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = evaluate(Bregman(F, ∇, inner), x, y)
48+
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = Bregman(F, ∇, inner)(x, y)

src/generic.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ abstract type SemiMetric <: PreMetric end
2121
#
2222
abstract type Metric <: SemiMetric end
2323

24+
evaluate(dist::PreMetric, a, b) = dist(a, b)
2425

2526
# Generic functions
2627

@@ -41,7 +42,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::Abs
4142
n = size(b, 2)
4243
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
4344
@inbounds for j = 1:n
44-
r[j] = evaluate(metric, a, view(b, :, j))
45+
r[j] = metric(a, view(b, :, j))
4546
end
4647
r
4748
end
@@ -50,7 +51,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
5051
n = size(a, 2)
5152
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
5253
@inbounds for j = 1:n
53-
r[j] = evaluate(metric, view(a, :, j), b)
54+
r[j] = metric(view(a, :, j), b)
5455
end
5556
r
5657
end
@@ -59,7 +60,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
5960
n = get_common_ncols(a, b)
6061
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
6162
@inbounds for j = 1:n
62-
r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
63+
r[j] = metric(view(a, :, j), view(b, :, j))
6364
end
6465
r
6566
end
@@ -97,7 +98,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric,
9798
@inbounds for j = 1:size(b, 2)
9899
bj = view(b, :, j)
99100
for i = 1:size(a, 2)
100-
r[i, j] = evaluate(metric, view(a, :, i), bj)
101+
r[i, j] = metric(view(a, :, i), bj)
101102
end
102103
end
103104
r
@@ -109,7 +110,7 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
109110
@inbounds for j = 1:n
110111
aj = view(a, :, j)
111112
for i = (j + 1):n
112-
r[i, j] = evaluate(metric, view(a, :, i), aj)
113+
r[i, j] = metric(view(a, :, i), aj)
113114
end
114115
r[j, j] = 0
115116
for i = 1:(j - 1)

src/haversine.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end
1212

1313
const VecOrLengthTwoTuple{T} = Union{AbstractVector{T}, NTuple{2, T}}
1414

15-
function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple)
15+
function (dist::Haversine)(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple)
1616
length(x) == length(y) == 2 || haversine_error()
1717

1818
@inbounds begin
@@ -33,6 +33,6 @@ function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTupl
3333
2 * dist.radius * asin( min(a, one(a)) ) # take care of floating point errors
3434
end
3535

36-
haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = evaluate(Haversine(radius), x, y)
36+
haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = Haversine(radius)(x, y)
3737

3838
@noinline haversine_error() = throw(ArgumentError("expected both inputs to have length 2 in Haversine distance"))

src/mahalanobis.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ result_type(::SqMahalanobis{T}, ::Type, ::Type) where {T} = T
1313

1414
# SqMahalanobis
1515

16-
function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
16+
function (dist::SqMahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
1717
if length(a) != length(b)
1818
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
1919
end
@@ -23,7 +23,7 @@ function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector)
2323
return dot(z, Q * z)
2424
end
2525

26-
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(SqMahalanobis(Q), a, b)
26+
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b)
2727

2828
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
2929
Q = dist.qmat
@@ -83,11 +83,11 @@ end
8383

8484
# Mahalanobis
8585

86-
function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
87-
sqrt(evaluate(SqMahalanobis(dist.qmat), a, b))
86+
function (dist::Mahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
87+
sqrt(SqMahalanobis(dist.qmat)(a, b))
8888
end
8989

90-
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(Mahalanobis(Q), a, b)
90+
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b)
9191

9292
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
9393
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))

0 commit comments

Comments
 (0)