Skip to content

Commit cb6dcb2

Browse files
authored
refactor *Mahalanobis to avoid unnecessary symmetry checks (#231)
1 parent 06d8361 commit cb6dcb2

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

src/mahalanobis.jl

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ function result_type(d::SqMahalanobis, ::Type{T1}, ::Type{T2}) where {T1,T2}
7272
return typeof(z * zero(eltype(d.qmat)) * z)
7373
end
7474

75-
# SqMahalanobis
76-
75+
# TODO: merge the following two once we lift the lower bound for julia (above v1.4?)
7776
function (dist::SqMahalanobis)(a::AbstractVector, b::AbstractVector)
7877
if length(a) != length(b)
7978
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
@@ -83,24 +82,47 @@ function (dist::SqMahalanobis)(a::AbstractVector, b::AbstractVector)
8382
z = a - b
8483
return dot(z, Q * z)
8584
end
85+
function (dist::Mahalanobis)(a::AbstractVector, b::AbstractVector)
86+
if length(a) != length(b)
87+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
88+
end
8689

87-
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b)
88-
89-
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix)
9090
Q = dist.qmat
91-
get_colwise_dims(size(Q, 1), r, a, b)
9291
z = a - b
93-
dot_percol!(r, Q * z, z)
92+
return sqrt(dot(z, Q * z))
9493
end
9594

96-
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractVector, b::AbstractMatrix)
95+
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b)
96+
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b)
97+
98+
function _colwise!(r, dist, a, b)
9799
Q = dist.qmat
98100
get_colwise_dims(size(Q, 1), r, a, b)
99101
z = a .- b
100102
dot_percol!(r, Q * z, z)
101103
end
102104

103-
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix)
105+
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix)
106+
_colwise!(r, dist, a, b)
107+
end
108+
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractVector, b::AbstractMatrix)
109+
_colwise!(r, dist, a, b)
110+
end
111+
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractVector)
112+
_colwise!(r, dist, a, b)
113+
end
114+
115+
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractMatrix)
116+
sqrt!(_colwise!(r, dist, a, b))
117+
end
118+
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractVector, b::AbstractMatrix)
119+
sqrt!(_colwise!(r, dist, a, b))
120+
end
121+
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractVector)
122+
sqrt!(_colwise!(r, dist, a, b))
123+
end
124+
125+
function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a::AbstractMatrix, b::AbstractMatrix)
104126
Q = dist.qmat
105127
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)
106128

@@ -112,13 +134,13 @@ function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis, a::AbstractMatrix, b
112134

113135
for j = 1:nb
114136
@simd for i = 1:na
115-
@inbounds r[i, j] = max(sa2[i] + sb2[j] - 2 * r[i, j], 0)
137+
@inbounds r[i, j] = eval_end(dist, max(sa2[i] + sb2[j] - 2 * r[i, j], 0))
116138
end
117139
end
118140
r
119141
end
120142

121-
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis, a::AbstractMatrix)
143+
function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a::AbstractMatrix)
122144
Q = dist.qmat
123145
m, n = get_pairwise_dims(size(Q, 1), r, a)
124146

@@ -132,33 +154,11 @@ function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis, a::AbstractMatrix)
132154
end
133155
r[j, j] = 0
134156
for i = (j + 1):n
135-
@inbounds r[i, j] = max(sa2[i] + sa2[j] - 2 * r[i, j], 0)
157+
@inbounds r[i, j] = eval_end(dist, max(sa2[i] + sa2[j] - 2 * r[i, j], 0))
136158
end
137159
end
138160
r
139161
end
140162

141-
142-
# Mahalanobis
143-
144-
function (dist::Mahalanobis)(a::AbstractVector, b::AbstractVector)
145-
sqrt(SqMahalanobis(dist.qmat, skipchecks = true)(a, b))
146-
end
147-
148-
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b)
149-
150-
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractMatrix)
151-
sqrt!(colwise!(r, SqMahalanobis(dist.qmat, skipchecks = true), a, b))
152-
end
153-
154-
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractVector, b::AbstractMatrix)
155-
sqrt!(colwise!(r, SqMahalanobis(dist.qmat, skipchecks = true), a, b))
156-
end
157-
158-
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis, a::AbstractMatrix, b::AbstractMatrix)
159-
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat, skipchecks = true), a, b))
160-
end
161-
162-
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis, a::AbstractMatrix)
163-
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat, skipchecks = true), a))
164-
end
163+
eval_end(::SqMahalanobis, x) = x
164+
eval_end(::Mahalanobis, x) = sqrt(x)

test/test_dists.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ end # testset
375375
Q = Q * Q' # make sure Q is positive-definite
376376
@test_throws DimensionMismatch mahalanobis(p, q, Q)
377377
@test_throws DimensionMismatch mahalanobis(q, q, Q)
378+
@test_throws DimensionMismatch sqmahalanobis(p, q, Q)
379+
@test_throws DimensionMismatch sqmahalanobis(q, q, Q)
378380
mat23 = [0.3 0.2 0.0; 0.1 0.0 0.4]
379381
mat22 = [0.3 0.2; 0.1 0.4]
380382
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat23)

0 commit comments

Comments
 (0)