@@ -72,8 +72,7 @@ function result_type(d::SqMahalanobis, ::Type{T1}, ::Type{T2}) where {T1,T2}
72
72
return typeof (z * zero (eltype (d. qmat)) * z)
73
73
end
74
74
75
- # SqMahalanobis
76
-
75
+ # TODO : merge the following two once we lift the lower bound for julia (above v1.4?)
77
76
function (dist:: SqMahalanobis )(a:: AbstractVector , b:: AbstractVector )
78
77
if length (a) != length (b)
79
78
throw (DimensionMismatch (" first array has length $(length (a)) which does not match the length of the second, $(length (b)) ." ))
@@ -83,24 +82,47 @@ function (dist::SqMahalanobis)(a::AbstractVector, b::AbstractVector)
83
82
z = a - b
84
83
return dot (z, Q * z)
85
84
end
85
+ function (dist:: Mahalanobis )(a:: AbstractVector , b:: AbstractVector )
86
+ if length (a) != length (b)
87
+ throw (DimensionMismatch (" first array has length $(length (a)) which does not match the length of the second, $(length (b)) ." ))
88
+ end
86
89
87
- sqmahalanobis (a:: AbstractVector , b:: AbstractVector , Q:: AbstractMatrix ) = SqMahalanobis (Q)(a, b)
88
-
89
- function colwise! (r:: AbstractArray , dist:: SqMahalanobis , a:: AbstractMatrix , b:: AbstractMatrix )
90
90
Q = dist. qmat
91
- get_colwise_dims (size (Q, 1 ), r, a, b)
92
91
z = a - b
93
- dot_percol! (r , Q * z, z )
92
+ return sqrt ( dot (z , Q * z) )
94
93
end
95
94
96
- function colwise! (r:: AbstractArray , dist:: SqMahalanobis , a:: AbstractVector , b:: AbstractMatrix )
95
+ sqmahalanobis (a:: AbstractVector , b:: AbstractVector , Q:: AbstractMatrix ) = SqMahalanobis (Q)(a, b)
96
+ mahalanobis (a:: AbstractVector , b:: AbstractVector , Q:: AbstractMatrix ) = Mahalanobis (Q)(a, b)
97
+
98
+ function _colwise! (r, dist, a, b)
97
99
Q = dist. qmat
98
100
get_colwise_dims (size (Q, 1 ), r, a, b)
99
101
z = a .- b
100
102
dot_percol! (r, Q * z, z)
101
103
end
102
104
103
- function _pairwise! (r:: AbstractMatrix , dist:: SqMahalanobis , a:: AbstractMatrix , b:: AbstractMatrix )
105
+ function colwise! (r:: AbstractArray , dist:: SqMahalanobis , a:: AbstractMatrix , b:: AbstractMatrix )
106
+ _colwise! (r, dist, a, b)
107
+ end
108
+ function colwise! (r:: AbstractArray , dist:: SqMahalanobis , a:: AbstractVector , b:: AbstractMatrix )
109
+ _colwise! (r, dist, a, b)
110
+ end
111
+ function colwise! (r:: AbstractArray , dist:: SqMahalanobis , a:: AbstractMatrix , b:: AbstractVector )
112
+ _colwise! (r, dist, a, b)
113
+ end
114
+
115
+ function colwise! (r:: AbstractArray , dist:: Mahalanobis , a:: AbstractMatrix , b:: AbstractMatrix )
116
+ sqrt! (_colwise! (r, dist, a, b))
117
+ end
118
+ function colwise! (r:: AbstractArray , dist:: Mahalanobis , a:: AbstractVector , b:: AbstractMatrix )
119
+ sqrt! (_colwise! (r, dist, a, b))
120
+ end
121
+ function colwise! (r:: AbstractArray , dist:: Mahalanobis , a:: AbstractMatrix , b:: AbstractVector )
122
+ sqrt! (_colwise! (r, dist, a, b))
123
+ end
124
+
125
+ function _pairwise! (r:: AbstractMatrix , dist:: Union{SqMahalanobis,Mahalanobis} , a:: AbstractMatrix , b:: AbstractMatrix )
104
126
Q = dist. qmat
105
127
m, na, nb = get_pairwise_dims (size (Q, 1 ), r, a, b)
106
128
@@ -112,13 +134,13 @@ function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis, a::AbstractMatrix, b
112
134
113
135
for j = 1 : nb
114
136
@simd for i = 1 : na
115
- @inbounds r[i, j] = max (sa2[i] + sb2[j] - 2 * r[i, j], 0 )
137
+ @inbounds r[i, j] = eval_end (dist, max (sa2[i] + sb2[j] - 2 * r[i, j], 0 ) )
116
138
end
117
139
end
118
140
r
119
141
end
120
142
121
- function _pairwise! (r:: AbstractMatrix , dist:: SqMahalanobis , a:: AbstractMatrix )
143
+ function _pairwise! (r:: AbstractMatrix , dist:: Union{ SqMahalanobis,Mahalanobis} , a:: AbstractMatrix )
122
144
Q = dist. qmat
123
145
m, n = get_pairwise_dims (size (Q, 1 ), r, a)
124
146
@@ -132,33 +154,11 @@ function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis, a::AbstractMatrix)
132
154
end
133
155
r[j, j] = 0
134
156
for i = (j + 1 ): n
135
- @inbounds r[i, j] = max (sa2[i] + sa2[j] - 2 * r[i, j], 0 )
157
+ @inbounds r[i, j] = eval_end (dist, max (sa2[i] + sa2[j] - 2 * r[i, j], 0 ) )
136
158
end
137
159
end
138
160
r
139
161
end
140
162
141
-
142
- # Mahalanobis
143
-
144
- function (dist:: Mahalanobis )(a:: AbstractVector , b:: AbstractVector )
145
- sqrt (SqMahalanobis (dist. qmat, skipchecks = true )(a, b))
146
- end
147
-
148
- mahalanobis (a:: AbstractVector , b:: AbstractVector , Q:: AbstractMatrix ) = Mahalanobis (Q)(a, b)
149
-
150
- function colwise! (r:: AbstractArray , dist:: Mahalanobis , a:: AbstractMatrix , b:: AbstractMatrix )
151
- sqrt! (colwise! (r, SqMahalanobis (dist. qmat, skipchecks = true ), a, b))
152
- end
153
-
154
- function colwise! (r:: AbstractArray , dist:: Mahalanobis , a:: AbstractVector , b:: AbstractMatrix )
155
- sqrt! (colwise! (r, SqMahalanobis (dist. qmat, skipchecks = true ), a, b))
156
- end
157
-
158
- function _pairwise! (r:: AbstractMatrix , dist:: Mahalanobis , a:: AbstractMatrix , b:: AbstractMatrix )
159
- sqrt! (_pairwise! (r, SqMahalanobis (dist. qmat, skipchecks = true ), a, b))
160
- end
161
-
162
- function _pairwise! (r:: AbstractMatrix , dist:: Mahalanobis , a:: AbstractMatrix )
163
- sqrt! (_pairwise! (r, SqMahalanobis (dist. qmat, skipchecks = true ), a))
164
- end
163
+ eval_end (:: SqMahalanobis , x) = x
164
+ eval_end (:: Mahalanobis , x) = sqrt (x)
0 commit comments