@@ -145,25 +145,48 @@ SqEuclidean() = SqEuclidean(0)
145
145
#
146
146
# ##########################################################
147
147
148
# Alias for a one-dimensional column view into a `Matrix{T}`, i.e. the type
# produced by `view(A, :, j)` on an `Array{T,2}`.
const ArraySlice{T} = SubArray{T,1,Array{T,2},Tuple{Base.Slice{Base.OneTo{Int}},Int},true}

# Fast path for plain `Array`s and matrix-column slices. Linear indexing over
# `eachindex(a, b)` is always valid for these types, so the `size(a) == size(b)`
# branch of the generic method is not needed. `Base.@propagate_inbounds` lets a
# caller's `@inbounds` also skip the length check guarded by `@boundscheck`.
@inline Base.@propagate_inbounds function evaluate(d::UnionMetrics, a::Union{Array, ArraySlice}, b::Union{Array, ArraySlice})
    @boundscheck if length(a) != length(b)
        throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    if isempty(a)
        # Empty inputs reduce to the additive identity of the result type.
        return zero(result_type(d, a, b))
    end
    @inbounds begin
        acc = eval_start(d, a, b)
        @simd for idx in eachindex(a, b)
            acc = eval_reduce(d, acc, eval_op(d, a[idx], b[idx]))
        end
        return eval_end(d, acc)
    end
end
168
+
169
# Generic fallback for arbitrary `AbstractArray`s. Equal shapes allow a shared
# SIMD-friendly index set; otherwise the two index spaces are walked in
# lockstep (the lengths were already checked to agree above).
@inline function evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray)
    @boundscheck if length(a) != length(b)
        throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    if isempty(a)
        # Empty inputs reduce to the additive identity of the result type.
        return zero(result_type(d, a, b))
    end
    @inbounds begin
        acc = eval_start(d, a, b)
        if size(a) == size(b)
            @simd for idx in eachindex(a, b)
                acc = eval_reduce(d, acc, eval_op(d, a[idx], b[idx]))
            end
        else
            for (ia, ib) in zip(eachindex(a), eachindex(b))
                acc = eval_reduce(d, acc, eval_op(d, a[ia], b[ib]))
            end
        end
    end
    return eval_end(d, acc)
end
@@ -200,7 +223,7 @@ cityblock(a::T, b::T) where {T <: Number} = evaluate(Cityblock(), a, b)
200
223
# Chebyshev (L-infinity) distance: the largest absolute componentwise difference.
@inline function eval_op(::Chebyshev, ai, bi)
    return abs(ai - bi)
end

@inline function eval_reduce(::Chebyshev, s1, s2)
    return max(s1, s2)
end

# Seed the reduction with the first pair, so an all-NaN input yields NaN.
@inline Base.@propagate_inbounds function eval_start(::Chebyshev, a::AbstractArray, b::AbstractArray)
    return abs(a[1] - b[1])
end

# Convenience wrappers.
chebyshev(a::AbstractArray, b::AbstractArray) = evaluate(Chebyshev(), a, b)
chebyshev(a::T, b::T) where {T <: Number} = evaluate(Chebyshev(), a, b)
206
229
@@ -218,7 +241,7 @@ hamming(a::AbstractArray, b::AbstractArray) = evaluate(Hamming(), a, b)
218
241
hamming(a::T, b::T) where {T <: Number} = evaluate(Hamming(), a, b)

# Cosine distance: accumulate the three running sums (a.b, a.a, b.b) in one pass.
@inline function eval_start(::CosineDist, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
    return zero(T), zero(T), zero(T)
end

@inline function eval_op(::CosineDist, ai, bi)
    return ai * bi, ai * ai, bi * bi
end
@@ -236,6 +259,8 @@ cosine_dist(a::AbstractArray, b::AbstractArray) = evaluate(CosineDist(), a, b)
236
259
# Correlation Dist
237
260
# Correlation distance: cosine distance between mean-centered vectors.
function _centralize(x::AbstractArray)
    return x .- mean(x)
end

evaluate(::CorrDist, a::AbstractArray, b::AbstractArray) = cosine_dist(_centralize(a), _centralize(b))
# Extra method on plain `Array` to resolve the dispatch ambiguity with the
# specialized `evaluate(::UnionMetrics, ::Union{Array,ArraySlice}, ...)` method.
evaluate(::CorrDist, a::Array, b::Array) = cosine_dist(_centralize(a), _centralize(b))

corr_dist(a::AbstractArray, b::AbstractArray) = evaluate(CorrDist(), a, b)
result_type(::CorrDist, a::AbstractArray, b::AbstractArray) = result_type(CosineDist(), a, b)
241
266
@@ -255,7 +280,7 @@ kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a,
255
280
gkl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(GenKLDivergence(), a, b)

# RenyiDivergence: the reduction state is two zero accumulators plus the totals
# of both inputs (captured up front — presumably used to normalize the inputs;
# confirm against `eval_op(::RenyiDivergence, ...)`).
@inline Base.@propagate_inbounds function eval_start(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
    return zero(T), zero(T), T(sum(a)), T(sum(b))
end
261
286
316
341
js_divergence(a::AbstractArray, b::AbstractArray) = evaluate(JSDivergence(), a, b)

# SpanNormDist: both components of the state are seeded with the first
# elementwise difference `a[1] - b[1]`.
@inline Base.@propagate_inbounds function eval_start(::SpanNormDist, a::AbstractArray, b::AbstractArray)
    return a[1] - b[1], a[1] - b[1]
end

@inline function eval_op(::SpanNormDist, ai, bi)
    return ai - bi
end
0 commit comments