Skip to content

Commit 7db5360

Browse files
authored
relax metrics to Real (#74)
* format files * relax metrics to Real and test it * fix missing ||
1 parent e26201a commit 7db5360

File tree

9 files changed

+507
-448
lines changed

9 files changed

+507
-448
lines changed

src/Distances.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ export
5858
rogerstanimoto,
5959
chebyshev,
6060
minkowski,
61-
mahalanobis,
6261

6362
hamming,
6463
cosine_dist,

src/common.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function get_colwise_dims(d::Int, r::AbstractArray, a::AbstractVector, b::Abstra
6464
end
6565

6666
function get_colwise_dims(d::Int, r::AbstractArray, a::AbstractMatrix, b::AbstractVector)
67-
size(a, 1) == length(b) == d
67+
size(a, 1) == length(b) == d ||
6868
throw(DimensionMismatch("Incorrect vector dimensions."))
6969
length(r) == size(a, 2) || throw(DimensionMismatch("Incorrect size of r."))
7070
return size(a)
@@ -109,10 +109,10 @@ function sumsq_percol(a::AbstractMatrix{T}) where {T}
109109
return r
110110
end
111111

112-
function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1,T2}
112+
function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1, T2}
113113
m = size(a, 1)
114114
n = size(a, 2)
115-
T = typeof(one(T1)*one(T2))
115+
T = typeof(one(T1) * one(T2))
116116
r = Vector{T}(n)
117117
for j = 1:n
118118
aj = view(a, :, j)
@@ -126,16 +126,16 @@ function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1,T2
126126
end
127127

128128
function dot_percol!(r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
129-
m = size(a,1)
130-
n = size(a,2)
131-
size(b) == (m,n) && length(r) == n ||
129+
m = size(a, 1)
130+
n = size(a, 2)
131+
size(b) == (m, n) && length(r) == n ||
132132
throw(DimensionMismatch("Inconsistent array dimensions."))
133133
for j = 1:n
134-
aj = view(a,:,j)
135-
bj = view(b,:,j)
134+
aj = view(a, :, j)
135+
bj = view(b, :, j)
136136
r[j] = dot(aj, bj)
137137
end
138138
return r
139139
end
140140

141-
dot_percol(a::AbstractMatrix, b::AbstractMatrix) = dot_percol!(Vector{Float64}(size(a,2)), a, b)
141+
dot_percol(a::AbstractMatrix, b::AbstractMatrix) = dot_percol!(Vector{Float64}(size(a, 2)), a, b)

src/generic.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ result_type(::PreMetric, ::AbstractArray, ::AbstractArray) = Float64
3232
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
3333
n = size(b, 2)
3434
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
35-
for j = 1 : n
35+
for j = 1:n
3636
@inbounds r[j] = evaluate(metric, a, view(b, :, j))
3737
end
3838
r
@@ -41,7 +41,7 @@ end
4141
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
4242
n = size(a, 2)
4343
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
44-
for j = 1 : n
44+
for j = 1:n
4545
@inbounds r[j] = evaluate(metric, view(a, :, j), b)
4646
end
4747
r
@@ -50,7 +50,7 @@ end
5050
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
5151
n = get_common_ncols(a, b)
5252
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
53-
for j = 1 : n
53+
for j = 1:n
5454
@inbounds r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
5555
end
5656
r
@@ -85,10 +85,10 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::A
8585
na = size(a, 2)
8686
nb = size(b, 2)
8787
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
88-
for j = 1 : size(b, 2)
89-
bj = view(b,:,j)
90-
for i = 1 : size(a, 2)
91-
@inbounds r[i,j] = evaluate(metric, view(a,:,i), bj)
88+
for j = 1:size(b, 2)
89+
bj = view(b, :, j)
90+
for i = 1:size(a, 2)
91+
@inbounds r[i, j] = evaluate(metric, view(a, :, i), bj)
9292
end
9393
end
9494
r
@@ -101,14 +101,14 @@ end
101101
function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
102102
n = size(a, 2)
103103
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
104-
for j = 1 : n
105-
aj = view(a,:,j)
106-
for i = j+1 : n
107-
@inbounds r[i,j] = evaluate(metric, view(a,:,i), aj)
104+
for j = 1:n
105+
aj = view(a, :, j)
106+
for i = (j + 1):n
107+
@inbounds r[i, j] = evaluate(metric, view(a, :, i), aj)
108108
end
109-
@inbounds r[j,j] = 0
110-
for i = 1 : j-1
111-
@inbounds r[i,j] = r[j,i] # leveraging the symmetry of SemiMetric
109+
@inbounds r[j, j] = 0
110+
for i = 1:(j - 1)
111+
@inbounds r[i, j] = r[j, i] # leveraging the symmetry of SemiMetric
112112
end
113113
end
114114
r

src/mahalanobis.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ result_type(::SqMahalanobis{T}, ::AbstractArray, ::AbstractArray) where {T} = T
1313

1414
# SqMahalanobis
1515

16-
function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: AbstractFloat}
16+
function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
1717
if length(a) != length(b)
1818
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
1919
end
@@ -25,22 +25,22 @@ end
2525

2626
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(SqMahalanobis(Q), a, b)
2727

28-
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
28+
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
2929
Q = dist.qmat
3030
m, n = get_colwise_dims(size(Q, 1), r, a, b)
3131
z = a - b
3232
dot_percol!(r, Q * z, z)
3333
end
3434

35-
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: AbstractFloat}
35+
function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: Real}
3636
Q = dist.qmat
3737
m, n = get_colwise_dims(size(Q, 1), r, a, b)
3838
z = a .- b
3939
Qz = Q * z
4040
dot_percol!(r, Q * z, z)
4141
end
4242

43-
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
43+
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
4444
Q = dist.qmat
4545
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)
4646

@@ -50,29 +50,29 @@ function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix,
5050
sb2 = dot_percol(b, Qb)
5151
At_mul_B!(r, a, Qb)
5252

53-
for j = 1 : nb
54-
@simd for i = 1 : na
55-
@inbounds r[i,j] = sa2[i] + sb2[j] - 2 * r[i,j]
53+
for j = 1:nb
54+
@simd for i = 1:na
55+
@inbounds r[i, j] = sa2[i] + sb2[j] - 2 * r[i, j]
5656
end
5757
end
5858
r
5959
end
6060

61-
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: AbstractFloat}
61+
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: Real}
6262
Q = dist.qmat
6363
m, n = get_pairwise_dims(size(Q, 1), r, a)
6464

6565
Qa = Q * a
6666
sa2 = dot_percol(a, Qa)
6767
At_mul_B!(r, a, Qa)
6868

69-
for j = 1 : n
70-
for i = 1 : j-1
71-
@inbounds r[i,j] = r[j,i]
69+
for j = 1:n
70+
for i = 1:(j - 1)
71+
@inbounds r[i, j] = r[j, i]
7272
end
73-
r[j,j] = 0
74-
for i = j+1 : n
75-
@inbounds r[i,j] = sa2[i] + sa2[j] - 2 * r[i,j]
73+
r[j, j] = 0
74+
for i = (j + 1):n
75+
@inbounds r[i, j] = sa2[i] + sa2[j] - 2 * r[i, j]
7676
end
7777
end
7878
r
@@ -81,24 +81,24 @@ end
8181

8282
# Mahalanobis
8383

84-
function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: AbstractFloat}
84+
function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
8585
sqrt(evaluate(SqMahalanobis(dist.qmat), a, b))
8686
end
8787

8888
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(Mahalanobis(Q), a, b)
8989

90-
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
90+
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
9191
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
9292
end
9393

94-
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: AbstractFloat}
94+
function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: Real}
9595
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
9696
end
9797

98-
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: AbstractFloat}
98+
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
9999
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a, b))
100100
end
101101

102-
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: AbstractFloat}
102+
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: Real}
103103
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a))
104104
end

0 commit comments

Comments
 (0)