Skip to content

Commit 840b285

Browse files
authored
fix 0.7 deprecations and improve performance for "normal" arrays and views (#91)
* split up evaluate to enable better codegen for Arrays * fix for 0.7
1 parent 8bf4ced commit 840b285

File tree

10 files changed

+101
-61
lines changed

10 files changed

+101
-61
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
julia 0.6
2+
Compat 0.54.0

benchmark/benchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ SUITE["colwise"] = BenchmarkGroup()
5353
function evaluate_colwise(dist, x, y)
5454
n = size(x, 2)
5555
T = typeof(evaluate(dist, x[:, 1], y[:, 1]))
56-
r = Vector{T}(n)
56+
r = Vector{T}(uninitialized, n)
5757
for j = 1:n
5858
r[j] = evaluate(dist, x[:, j], y[:, j])
5959
end

src/Distances.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ __precompile__()
22

33
module Distances
44

5+
using Compat
6+
using Compat.LinearAlgebra
7+
58
export
69
# generic types/functions
710
PreMetric,

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ end
101101
function sumsq_percol(a::AbstractMatrix{T}) where {T}
102102
m = size(a, 1)
103103
n = size(a, 2)
104-
r = Vector{T}(n)
104+
r = Vector{T}(uninitialized, n)
105105
for j = 1:n
106106
aj = view(a, :, j)
107107
r[j] = dot(aj, aj)
@@ -113,7 +113,7 @@ function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1, T
113113
m = size(a, 1)
114114
n = size(a, 2)
115115
T = typeof(one(T1) * one(T2))
116-
r = Vector{T}(n)
116+
r = Vector{T}(uninitialized, n)
117117
for j = 1:n
118118
aj = view(a, :, j)
119119
s = zero(T)

src/generic.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,26 @@ result_type(::PreMetric, ::AbstractArray, ::AbstractArray) = Float64
3232
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
3333
n = size(b, 2)
3434
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
35-
for j = 1:n
36-
@inbounds r[j] = evaluate(metric, a, view(b, :, j))
35+
@inbounds for j = 1:n
36+
r[j] = evaluate(metric, a, view(b, :, j))
3737
end
3838
r
3939
end
4040

4141
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
4242
n = size(a, 2)
4343
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
44-
for j = 1:n
45-
@inbounds r[j] = evaluate(metric, view(a, :, j), b)
44+
@inbounds for j = 1:n
45+
r[j] = evaluate(metric, view(a, :, j), b)
4646
end
4747
r
4848
end
4949

5050
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
5151
n = get_common_ncols(a, b)
5252
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
53-
for j = 1:n
54-
@inbounds r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
53+
@inbounds for j = 1:n
54+
r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
5555
end
5656
r
5757
end
@@ -62,19 +62,19 @@ end
6262

6363
function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
6464
n = get_common_ncols(a, b)
65-
r = Vector{result_type(metric, a, b)}(n)
65+
r = Vector{result_type(metric, a, b)}(uninitialized, n)
6666
colwise!(r, metric, a, b)
6767
end
6868

6969
function colwise(metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
7070
n = size(b, 2)
71-
r = Vector{result_type(metric, a, b)}(n)
71+
r = Vector{result_type(metric, a, b)}(uninitialized, n)
7272
colwise!(r, metric, a, b)
7373
end
7474

7575
function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
7676
n = size(a, 2)
77-
r = Vector{result_type(metric, a, b)}(n)
77+
r = Vector{result_type(metric, a, b)}(uninitialized, n)
7878
colwise!(r, metric, a, b)
7979
end
8080

@@ -85,10 +85,10 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::A
8585
na = size(a, 2)
8686
nb = size(b, 2)
8787
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
88-
for j = 1:size(b, 2)
88+
@inbounds for j = 1:size(b, 2)
8989
bj = view(b, :, j)
9090
for i = 1:size(a, 2)
91-
@inbounds r[i, j] = evaluate(metric, view(a, :, i), bj)
91+
r[i, j] = evaluate(metric, view(a, :, i), bj)
9292
end
9393
end
9494
r
@@ -101,14 +101,14 @@ end
101101
function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
102102
n = size(a, 2)
103103
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
104-
for j = 1:n
104+
@inbounds for j = 1:n
105105
aj = view(a, :, j)
106106
for i = (j + 1):n
107-
@inbounds r[i, j] = evaluate(metric, view(a, :, i), aj)
107+
r[i, j] = evaluate(metric, view(a, :, i), aj)
108108
end
109-
@inbounds r[j, j] = 0
109+
r[j, j] = 0
110110
for i = 1:(j - 1)
111-
@inbounds r[i, j] = r[j, i] # leveraging the symmetry of SemiMetric
111+
r[i, j] = r[j, i] # leveraging the symmetry of SemiMetric
112112
end
113113
end
114114
r
@@ -117,12 +117,12 @@ end
117117
function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
118118
m = size(a, 2)
119119
n = size(b, 2)
120-
r = Matrix{result_type(metric, a, b)}(m, n)
120+
r = Matrix{result_type(metric, a, b)}(uninitialized, m, n)
121121
pairwise!(r, metric, a, b)
122122
end
123123

124124
function pairwise(metric::PreMetric, a::AbstractMatrix)
125125
n = size(a, 2)
126-
r = Matrix{result_type(metric, a, a)}(n, n)
126+
r = Matrix{result_type(metric, a, a)}(uninitialized, n, n)
127127
pairwise!(r, metric, a)
128128
end

src/metrics.jl

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,25 +145,48 @@ SqEuclidean() = SqEuclidean(0)
145145
#
146146
###########################################################
147147

148-
function evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray)
149-
if length(a) != length(b)
148+
const ArraySlice{T} = SubArray{T,1,Array{T,2},Tuple{Base.Slice{Base.OneTo{Int}},Int},true}
149+
150+
# Specialized for Arrays and avoids a branch on the size
151+
@inline Base.@propagate_inbounds function evaluate(d::UnionMetrics, a::Union{Array, ArraySlice}, b::Union{Array, ArraySlice})
152+
@boundscheck if length(a) != length(b)
150153
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
151154
end
152155
if length(a) == 0
153156
return zero(result_type(d, a, b))
154157
end
155-
s = eval_start(d, a, b)
156-
if size(a) == size(b)
158+
@inbounds begin
159+
s = eval_start(d, a, b)
157160
@simd for I in eachindex(a, b)
158-
@inbounds ai = a[I]
159-
@inbounds bi = b[I]
161+
ai = a[I]
162+
bi = b[I]
160163
s = eval_reduce(d, s, eval_op(d, ai, bi))
161164
end
162-
else
163-
for (Ia, Ib) in zip(eachindex(a), eachindex(b))
164-
@inbounds ai = a[Ia]
165-
@inbounds bi = b[Ib]
166-
s = eval_reduce(d, s, eval_op(d, ai, bi))
165+
return eval_end(d, s)
166+
end
167+
end
168+
169+
@inline function evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray)
170+
@boundscheck if length(a) != length(b)
171+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
172+
end
173+
if length(a) == 0
174+
return zero(result_type(d, a, b))
175+
end
176+
@inbounds begin
177+
s = eval_start(d, a, b)
178+
if size(a) == size(b)
179+
@simd for I in eachindex(a, b)
180+
ai = a[I]
181+
bi = b[I]
182+
s = eval_reduce(d, s, eval_op(d, ai, bi))
183+
end
184+
else
185+
for (Ia, Ib) in zip(eachindex(a), eachindex(b))
186+
ai = a[Ia]
187+
bi = b[Ib]
188+
s = eval_reduce(d, s, eval_op(d, ai, bi))
189+
end
167190
end
168191
end
169192
return eval_end(d, s)
@@ -200,7 +223,7 @@ cityblock(a::T, b::T) where {T <: Number} = evaluate(Cityblock(), a, b)
200223
@inline eval_op(::Chebyshev, ai, bi) = abs(ai - bi)
201224
@inline eval_reduce(::Chebyshev, s1, s2) = max(s1, s2)
202225
# if only NaN, will output NaN
203-
@inline eval_start(::Chebyshev, a::AbstractArray, b::AbstractArray) = abs(a[1] - b[1])
226+
@inline Base.@propagate_inbounds eval_start(::Chebyshev, a::AbstractArray, b::AbstractArray) = abs(a[1] - b[1])
204227
chebyshev(a::AbstractArray, b::AbstractArray) = evaluate(Chebyshev(), a, b)
205228
chebyshev(a::T, b::T) where {T <: Number} = evaluate(Chebyshev(), a, b)
206229

@@ -218,7 +241,7 @@ hamming(a::AbstractArray, b::AbstractArray) = evaluate(Hamming(), a, b)
218241
hamming(a::T, b::T) where {T <: Number} = evaluate(Hamming(), a, b)
219242

220243
# Cosine dist
221-
function eval_start(::CosineDist, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
244+
@inline function eval_start(::CosineDist, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
222245
zero(T), zero(T), zero(T)
223246
end
224247
@inline eval_op(::CosineDist, ai, bi) = ai * bi, ai * ai, bi * bi
@@ -236,6 +259,8 @@ cosine_dist(a::AbstractArray, b::AbstractArray) = evaluate(CosineDist(), a, b)
236259
# Correlation Dist
237260
_centralize(x::AbstractArray) = x .- mean(x)
238261
evaluate(::CorrDist, a::AbstractArray, b::AbstractArray) = cosine_dist(_centralize(a), _centralize(b))
262+
# Ambiguity resolution
263+
evaluate(::CorrDist, a::Array, b::Array) = cosine_dist(_centralize(a), _centralize(b))
239264
corr_dist(a::AbstractArray, b::AbstractArray) = evaluate(CorrDist(), a, b)
240265
result_type(::CorrDist, a::AbstractArray, b::AbstractArray) = result_type(CosineDist(), a, b)
241266

@@ -255,7 +280,7 @@ kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a,
255280
gkl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(GenKLDivergence(), a, b)
256281

257282
# RenyiDivergence
258-
function eval_start(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
283+
@inline Base.@propagate_inbounds function eval_start(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
259284
zero(T), zero(T), T(sum(a)), T(sum(b))
260285
end
261286

@@ -316,7 +341,7 @@ end
316341
js_divergence(a::AbstractArray, b::AbstractArray) = evaluate(JSDivergence(), a, b)
317342

318343
# SpanNormDist
319-
function eval_start(::SpanNormDist, a::AbstractArray, b::AbstractArray)
344+
@inline Base.@propagate_inbounds function eval_start(::SpanNormDist, a::AbstractArray, b::AbstractArray)
320345
a[1] - b[1], a[1] - b[1]
321346
end
322347
@inline eval_op(::SpanNormDist, ai, bi) = ai - bi

src/wmetrics.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,37 +45,39 @@ end
4545
function result_type(dist::UnionWeightedMetrics, ::AbstractArray{T1}, ::AbstractArray{T2}) where {T1, T2}
4646
typeof(evaluate(dist, one(T1), one(T2)))
4747
end
48-
function eval_start(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray)
48+
@inline function eval_start(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray)
4949
zero(result_type(d, a, b))
5050
end
5151
eval_end(d::UnionWeightedMetrics, s) = s
5252

5353

5454

55-
function evaluate(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray)
56-
if length(a) != length(b)
55+
@inline function evaluate(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray)
56+
@boundscheck if length(a) != length(b)
5757
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
5858
end
59-
if length(a) != length(d.weights)
59+
@boundscheck if length(a) != length(d.weights)
6060
throw(DimensionMismatch("arrays have length $(length(a)) but weights have length $(length(d.weights))."))
6161
end
6262
if length(a) == 0
6363
return zero(result_type(d, a, b))
6464
end
65-
s = eval_start(d, a, b)
66-
if size(a) == size(b)
67-
@simd for I in eachindex(a, b, d.weights)
68-
@inbounds ai = a[I]
69-
@inbounds bi = b[I]
70-
@inbounds wi = d.weights[I]
71-
s = eval_reduce(d, s, eval_op(d, ai, bi, wi))
72-
end
73-
else
74-
for (Ia, Ib, Iw) in zip(eachindex(a), eachindex(b), eachindex(d.weights))
75-
@inbounds ai = a[Ia]
76-
@inbounds bi = b[Ib]
77-
@inbounds wi = d.weights[Iw]
78-
s = eval_reduce(d, s, eval_op(d, ai, bi, wi))
65+
@inbounds begin
66+
s = eval_start(d, a, b)
67+
if size(a) == size(b)
68+
@simd for I in eachindex(a, b, d.weights)
69+
ai = a[I]
70+
bi = b[I]
71+
wi = d.weights[I]
72+
s = eval_reduce(d, s, eval_op(d, ai, bi, wi))
73+
end
74+
else
75+
for (Ia, Ib, Iw) in zip(eachindex(a), eachindex(b), eachindex(d.weights))
76+
ai = a[Ia]
77+
bi = b[Ib]
78+
wi = d.weights[Iw]
79+
s = eval_reduce(d, s, eval_op(d, ai, bi, wi))
80+
end
7981
end
8082
end
8183
return eval_end(d, s)

test/F64.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
# dummy type wrapping a Float64 used in tests
2-
struct F64 <: Real
2+
struct F64 <: AbstractFloat
33
x::Float64
44
end
5+
F64(x::F64) = x
56

67
# operations
7-
for op in (:+, :-)
8+
for op in (:+, :-, :sin, :cos, :asin, :acos)
89
@eval Base.$op(a::F64) = F64($op(a.x))
910
end
10-
for op in (:+, :-, :*, :/)
11+
for op in (:+, :-, :*, :/, :atan2)
1112
@eval Base.$op(a::F64, b::F64) = F64($op(a.x, b.x))
1213
end
13-
for op in (:zero, :one)
14+
for op in (:zero, :one,)
1415
@eval Base.$op(::Type{F64}) = F64($op(Float64))
1516
end
16-
Base.rand(rng::AbstractRNG, ::Type{F64}) = F64(rand())
17+
18+
if VERSION.minor >= 7
19+
Random.rand(rng::AbstractRNG, ::Random.SamplerTrivial{Random.CloseOpen01{F64}}) = F64(rand(rng))
20+
else
21+
Base.rand(rng::AbstractRNG, ::Type{F64}) = F64(rand())
22+
end
1723
Base.sqrt(a::F64) = F64(sqrt(a.x))
1824
Base.:^(a::F64, b::Number) = F64(a.x^b)
1925
Base.:^(a::F64, b::Int) = F64(a.x^b)
@@ -32,8 +38,9 @@ Base.eps(::Type{F64}) = eps(Float64)
3238
# promotion
3339
Base.promote_type(::Type{Float32}, ::Type{F64}) = Float64 # for eig
3440
Base.promote_type(::Type{Float64}, ::Type{F64}) = Float64 # for vecnorm
35-
Base.promote(a::F64, b::T) where {T <: Number} = a, F64(b)
36-
Base.promote(a::T, b::F64) where {T <: Number} = F64(a), b
41+
Base.promote(a::F64, b::T) where {T <: Number} = a, F64(float(b))
42+
Base.promote(a::T, b::F64) where {T <: Number} = F64(float(a)), b
43+
3744
Base.convert(::Type{F64}, a::F64) = a
3845
Base.convert(::Type{Float64}, a::F64) = a.x
3946
Base.convert(::Type{F64}, a::T) where {T <: Number} = F64(a)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Distances
22
using Compat.Test
3+
using Compat.LinearAlgebra
4+
using Compat.Random
35

46
include("F64.jl")
57
include("test_dists.jl")

test/test_dists.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ end
158158
w = ones(4)
159159
@test sqeuclidean(x, y) wsqeuclidean(x, y, w)
160160

161-
w = rand(size(x))
161+
w = rand(Float64, size(x))
162162
@test wsqeuclidean(x, y, w) dot((x - vec(y)).^2, w)
163163
@test weuclidean(x, y, w) == sqrt(wsqeuclidean(x, y, w))
164164
@test wcityblock(x, y, w) dot(abs.(x - vec(y)), w)

0 commit comments

Comments
 (0)