Skip to content

Commit b3a1a35

Browse files
authored
fix type instability in Jaccard for float input (#44)
* fix type instability in Jaccard and add some inbounds
1 parent b608c83 commit b3a1a35

File tree

2 files changed

+48
-38
lines changed

2 files changed

+48
-38
lines changed

src/metrics.jl

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,22 @@ end
176176

177177
# Jaccard
178178

179-
@inline eval_start(::Jaccard, a::AbstractArray, b::AbstractArray) = 0, 0
179+
@inline eval_start(::Jaccard, a::AbstractArray{Bool}, b::AbstractArray{Bool}) = 0, 0
180+
@inline eval_start{T}(::Jaccard, a::AbstractArray{T}, b::AbstractArray{T}) = zero(T), zero(T)
180181
@inline function eval_op(::Jaccard, s1, s2)
181-
denominator = max(s1, s2)
182-
numerator = min(s1, s2)
183-
numerator, denominator
182+
abs_m = abs(s1 - s2)
183+
abs_p = abs(s1 + s2)
184+
abs_p - abs_m, abs_p + abs_m
184185
end
185186
@inline function eval_reduce(::Jaccard, s1, s2)
186-
a = s1[1] + s2[1]
187-
b = s1[2] + s2[2]
187+
@inbounds a = s1[1] + s2[1]
188+
@inbounds b = s1[2] + s2[2]
188189
a, b
189190
end
190-
@inline eval_end(::Jaccard, a) = 1 - (a[1]/a[2])
191+
@inline function eval_end(::Jaccard, a)
192+
@inbounds v = 1 - (a[1]/a[2])
193+
return v
194+
end
191195
jaccard(a::AbstractArray, b::AbstractArray) = evaluate(Jaccard(), a, b)
192196

193197
# Tanimoto
@@ -201,15 +205,17 @@ jaccard(a::AbstractArray, b::AbstractArray) = evaluate(Jaccard(), a, b)
201205
tt, tf, ft, ff
202206
end
203207
@inline function eval_reduce(::RogersTanimoto, s1, s2)
204-
a = s1[1] + s2[1]
205-
b = s1[2] + s2[2]
206-
c = s1[3] + s2[3]
207-
d = s1[4] + s1[4]
208+
@inbounds begin
209+
a = s1[1] + s2[1]
210+
b = s1[2] + s2[2]
211+
c = s1[3] + s2[3]
212+
d = s1[4] + s1[4]
213+
end
208214
a, b, c, d
209215
end
210216
@inline function eval_end(::RogersTanimoto, a)
211-
numerator = 2(a[2] + a[3])
212-
denominator = a[1] + a[4] + 2(a[2] + a[3])
217+
@inbounds numerator = 2(a[2] + a[3])
218+
@inbounds denominator = a[1] + a[4] + 2(a[2] + a[3])
213219
numerator / denominator
214220
end
215221
rogerstanimoto{T <: Bool}(a::AbstractArray{T}, b::AbstractArray{T}) = evaluate(RogersTanimoto(), a, b)
@@ -240,13 +246,13 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
240246
m, n = get_pairwise_dims(r, a)
241247
At_mul_B!(r, a, a)
242248
sa2 = sumsq_percol(a)
243-
for j = 1 : n
249+
@inbounds for j = 1 : n
244250
for i = 1 : j-1
245-
@inbounds r[i,j] = r[j,i]
251+
r[i,j] = r[j,i]
246252
end
247-
@inbounds r[j,j] = 0
253+
r[j,j] = 0
248254
for i = j+1 : n
249-
@inbounds r[i,j] = sa2[i] + sa2[j] - 2 * r[i,j]
255+
r[i,j] = sa2[i] + sa2[j] - 2 * r[i,j]
250256
end
251257
end
252258
r
@@ -258,10 +264,10 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::Abs
258264
At_mul_B!(r, a, b)
259265
sa2 = sumsq_percol(a)
260266
sb2 = sumsq_percol(b)
261-
for j = 1 : nb
267+
@inbounds for j = 1 : nb
262268
for i = 1 : na
263-
@inbounds v = sa2[i] + sb2[j] - 2 * r[i,j]
264-
@inbounds r[i,j] = isnan(v) ? NaN : sqrt(max(v, 0.))
269+
v = sa2[i] + sb2[j] - 2 * r[i,j]
270+
r[i,j] = isnan(v) ? NaN : sqrt(max(v, 0.))
265271
end
266272
end
267273
r
@@ -271,14 +277,14 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
271277
m, n = get_pairwise_dims(r, a)
272278
At_mul_B!(r, a, a)
273279
sa2 = sumsq_percol(a)
274-
for j = 1 : n
280+
@inbounds for j = 1 : n
275281
for i = 1 : j-1
276-
@inbounds r[i,j] = r[j,i]
282+
r[i,j] = r[j,i]
277283
end
278284
@inbounds r[j,j] = 0
279285
for i = j+1 : n
280-
@inbounds v = sa2[i] + sa2[j] - 2 * r[i,j]
281-
@inbounds r[i,j] = isnan(v) ? NaN : sqrt(max(v, 0.))
286+
v = sa2[i] + sa2[j] - 2 * r[i,j]
287+
r[i,j] = isnan(v) ? NaN : sqrt(max(v, 0.))
282288
end
283289
end
284290
r
@@ -302,13 +308,13 @@ function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
302308
m, n = get_pairwise_dims(r, a)
303309
At_mul_B!(r, a, a)
304310
ra = sqrt!(sumsq_percol(a))
305-
for j = 1 : n
311+
@inbounds for j = 1 : n
306312
@simd for i = j+1 : n
307-
@inbounds r[i,j] = max(1 - r[i,j] / (ra[i] * ra[j]), 0)
313+
r[i,j] = max(1 - r[i,j] / (ra[i] * ra[j]), 0)
308314
end
309-
@inbounds r[j,j] = 0
315+
r[j,j] = 0
310316
for i = 1 : j-1
311-
@inbounds r[i,j] = r[j,i]
317+
r[i,j] = r[j,i]
312318
end
313319
end
314320
r

test/test_dists.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),
108108
@test whamming(a, b, w) == sum((a .!= b) .* w)
109109
end
110110

111+
@inferred evaluate(Jaccard(), rand(3), rand(3))
112+
@inferred evaluate(Jaccard(), [1,2,3], [1,2,3])
113+
@inferred evaluate(Jaccard(), [true, false, true], [false, true, true])
114+
111115
end # testset
112116

113117

@@ -188,28 +192,28 @@ q = rand(12)
188192
px = x ./ sum(x)
189193
py = y ./ sum(y)
190194
expected_bc_x_y = sum(sqrt(px .* py))
191-
@test Distances.bhattacharyya_coeff(x, y) ≈ expected_bc_x_y
192-
@test bhattacharyya(x, y) ≈ (-log(expected_bc_x_y))
193-
@test hellinger(x, y) ≈ sqrt(1 - expected_bc_x_y)
195+
@test Distances.bhattacharyya_coeff(x, y) ≈ expected_bc_x_y
196+
@test bhattacharyya(x, y) ≈ (-log(expected_bc_x_y))
197+
@test hellinger(x, y) ≈ sqrt(1 - expected_bc_x_y)
194198

195199

196200

197201
pa = a ./ sum(a)
198202
pb = b ./ sum(b)
199203
expected_bc_a_b = sum(sqrt(pa .* pb))
200-
@test Distances.bhattacharyya_coeff(a, b) ≈ expected_bc_a_b
201-
@test bhattacharyya(a, b) ≈ (-log(expected_bc_a_b))
202-
@test hellinger(a, b) ≈ sqrt(1 - expected_bc_a_b)
204+
@test Distances.bhattacharyya_coeff(a, b) ≈ expected_bc_a_b
205+
@test bhattacharyya(a, b) ≈ (-log(expected_bc_a_b))
206+
@test hellinger(a, b) ≈ sqrt(1 - expected_bc_a_b)
203207

204208
pp = p ./ sum(p)
205209
pq = q ./ sum(q)
206210
expected_bc_p_q = sum(sqrt(pp .* pq))
207-
@test Distances.bhattacharyya_coeff(p, q) ≈ expected_bc_p_q
208-
@test bhattacharyya(p, q) ≈ (-log(expected_bc_p_q))
209-
@test hellinger(p, q) ≈ sqrt(1 - expected_bc_p_q)
211+
@test Distances.bhattacharyya_coeff(p, q) ≈ expected_bc_p_q
212+
@test bhattacharyya(p, q) ≈ (-log(expected_bc_p_q))
213+
@test hellinger(p, q) ≈ sqrt(1 - expected_bc_p_q)
210214

211215
# Ensure it is semimetric
212-
@test bhattacharyya(x, y) ≈ bhattacharyya(y, x)
216+
@test bhattacharyya(x, y) ≈ bhattacharyya(y, x)
213217

214218
end #testset
215219

0 commit comments

Comments
 (0)