Skip to content

Commit a1f2c3e

Browse files
authored
Merge pull request #121 from JuliaStats/nl/dims
Add dims argument to pairwise
2 parents 785aaab + 73333aa commit a1f2c3e

File tree

7 files changed

+122
-47
lines changed

7 files changed

+122
-47
lines changed

README.md

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ r = euclidean(x, y)
6969

7070
#### Computing distances between corresponding columns
7171

72-
Suppose you have two ``m-by-n`` matrix ``X`` and ``Y``, then you can compute all distances between corresponding columns of X and Y in one batch, using the ``colwise`` function, as
72+
Suppose you have two ``m-by-n`` matrices ``X`` and ``Y``, then you can compute all distances between corresponding columns of ``X`` and ``Y`` in one batch, using the ``colwise`` function, as
7373

7474
```julia
7575
r = colwise(dist, X, Y)
@@ -81,31 +81,35 @@ Note that either of ``X`` and ``Y`` can be just a single vector -- then the ``co
8181

8282
#### Computing pairwise distances
8383

84-
Let ``X`` and ``Y`` respectively have ``m`` and ``n`` columns. Then the ``pairwise`` function computes distances between each pair of columns in ``X`` and ``Y``:
84+
Let ``X`` and ``Y`` respectively have ``m`` and ``n`` columns. Then the ``pairwise`` function with the ``dims=2`` argument computes distances between each pair of columns in ``X`` and ``Y``:
8585

8686
```julia
87-
R = pairwise(dist, X, Y)
87+
R = pairwise(dist, X, Y, dims=2)
8888
```
8989

9090
In the output, ``R`` is a matrix of size ``(m, n)``, such that ``R[i,j]`` is the distance between ``X[:,i]`` and ``Y[:,j]``. Computing distances for all pairs using the ``pairwise`` function is often remarkably faster than evaluating for each pair individually.
9191

9292
If you just want to compute distances between columns of a matrix ``X``, you can write
9393

9494
```julia
95-
R = pairwise(dist, X)
95+
R = pairwise(dist, X, dims=2)
9696
```
9797

9898
This statement will result in an ``m-by-m`` matrix, where ``R[i,j]`` is the distance between ``X[:,i]`` and ``X[:,j]``.
9999
``pairwise(dist, X)`` is typically more efficient than ``pairwise(dist, X, X)``, as the former will take advantage of the symmetry when ``dist`` is a semi-metric (including metric).
100100

101+
For performance reasons, it is recommended to use matrices with observations in columns (as shown above). Indeed,
102+
the ``Array`` type in Julia is column-major, making it more efficient to access memory column by column. However,
103+
matrices with observations stored in rows are also supported via the argument ``dims=1``.
104+
101105
#### Computing column-wise and pairwise distances inplace
102106

103-
If the vector/matrix to store the results are pre-allocated, you may use the storage (without creating a new array) using the following syntax:
107+
If the vector/matrix to store the results is pre-allocated, you may use the storage (without creating a new array) using the following syntax (``i`` being either ``1`` or ``2``):
104108

105109
```julia
106110
colwise!(r, dist, X, Y)
107-
pairwise!(R, dist, X, Y)
108-
pairwise!(R, dist, X)
111+
pairwise!(R, dist, X, Y, dims=i)
112+
pairwise!(R, dist, X, dims=i)
109113
```
110114

111115
Please pay attention to the difference: the functions for inplace computation are ``colwise!`` and ``pairwise!`` (instead of ``colwise`` and ``pairwise``).

src/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535
function get_pairwise_dims(r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix)
3636
ma, na = size(a)
3737
mb, nb = size(b)
38-
ma == mb || throw(DimensionMismatch("The numbers of rows in a and b must match."))
38+
ma == mb || throw(DimensionMismatch("The numbers of rows or columns in a and b must match."))
3939
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
4040
return (ma, na, nb)
4141
end

src/generic.jl

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ end
8181

8282
# Generic pairwise evaluation
8383

84-
function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
84+
function _pairwise!(r::AbstractMatrix, metric::PreMetric,
85+
a::AbstractMatrix, b::AbstractMatrix=a)
8586
na = size(a, 2)
8687
nb = size(b, 2)
8788
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
@@ -94,11 +95,7 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::A
9495
r
9596
end
9697

97-
function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix)
98-
pairwise!(r, metric, a, a)
99-
end
100-
101-
function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
98+
function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
10299
n = size(a, 2)
103100
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
104101
@inbounds for j = 1:n
@@ -114,15 +111,75 @@ function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
114111
r
115112
end
116113

117-
function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
118-
m = size(a, 2)
119-
n = size(b, 2)
114+
function deprecated_dims(dims::Union{Nothing,Integer})
115+
if dims === nothing
116+
Base.depwarn("implicit `dims=2` argument now has to be passed explicitly " *
117+
"to specify that distances between columns should be computed",
118+
:pairwise!)
119+
return 2
120+
else
121+
return dims
122+
end
123+
end
124+
125+
function pairwise!(r::AbstractMatrix, metric::PreMetric,
126+
a::AbstractMatrix, b::AbstractMatrix;
127+
dims::Union{Nothing,Integer}=nothing)
128+
dims = deprecated_dims(dims)
129+
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
130+
if dims == 1
131+
na, ma = size(a)
132+
nb, mb = size(b)
133+
ma == mb || throw(DimensionMismatch("The numbers of columns in a and b " *
134+
"must match (got $ma and $mb)."))
135+
else
136+
ma, na = size(a)
137+
mb, nb = size(b)
138+
ma == mb || throw(DimensionMismatch("The numbers of rows in a and b " *
139+
"must match (got $ma and $mb)."))
140+
end
141+
size(r) == (na, nb) ||
142+
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb)))."))
143+
if dims == 1
144+
_pairwise!(r, metric, transpose(a), transpose(b))
145+
else
146+
_pairwise!(r, metric, a, b)
147+
end
148+
end
149+
150+
function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
151+
dims::Union{Nothing,Integer}=nothing)
152+
dims = deprecated_dims(dims)
153+
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
154+
if dims == 1
155+
n, m = size(a)
156+
else
157+
m, n = size(a)
158+
end
159+
size(r) == (n, n) ||
160+
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n)))."))
161+
if dims == 1
162+
_pairwise!(r, metric, transpose(a))
163+
else
164+
_pairwise!(r, metric, a)
165+
end
166+
end
167+
168+
function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
169+
dims::Union{Nothing,Integer}=nothing)
170+
dims = deprecated_dims(dims)
171+
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
172+
m = size(a, dims)
173+
n = size(b, dims)
120174
r = Matrix{result_type(metric, a, b)}(undef, m, n)
121-
pairwise!(r, metric, a, b)
175+
pairwise!(r, metric, a, b, dims=dims)
122176
end
123177

124-
function pairwise(metric::PreMetric, a::AbstractMatrix)
125-
n = size(a, 2)
178+
function pairwise(metric::PreMetric, a::AbstractMatrix;
179+
dims::Union{Nothing,Integer}=nothing)
180+
dims = deprecated_dims(dims)
181+
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
182+
n = size(a, dims)
126183
r = Matrix{result_type(metric, a, a)}(undef, n, n)
127-
pairwise!(r, metric, a)
184+
pairwise!(r, metric, a, dims=dims)
128185
end

src/mahalanobis.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b
4040
dot_percol!(r, Q * z, z)
4141
end
4242

43-
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
43+
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
44+
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
4445
Q = dist.qmat
4546
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)
4647

@@ -58,7 +59,8 @@ function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix,
5859
r
5960
end
6061

61-
function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: Real}
62+
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
63+
a::AbstractMatrix) where {T <: Real}
6264
Q = dist.qmat
6365
m, n = get_pairwise_dims(size(Q, 1), r, a)
6466

@@ -95,10 +97,12 @@ function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::
9597
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
9698
end
9799

98-
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
99-
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a, b))
100+
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis{T},
101+
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
102+
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat), a, b))
100103
end
101104

102-
function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: Real}
103-
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a))
105+
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis{T},
106+
a::AbstractMatrix) where {T <: Real}
107+
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat), a))
104108
end

src/metrics.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,8 @@ nrmsd(a, b) = evaluate(NormRMSDeviation(), a, b)
462462
###########################################################
463463

464464
# SqEuclidean
465-
function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::AbstractMatrix)
465+
function _pairwise!(r::AbstractMatrix, dist::SqEuclidean,
466+
a::AbstractMatrix, b::AbstractMatrix)
466467
mul!(r, a', b)
467468
sa2 = sum(abs2, a, dims=1)
468469
sb2 = sum(abs2, b, dims=1)
@@ -498,7 +499,7 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::A
498499
r
499500
end
500501

501-
function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
502+
function _pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
502503
m, n = get_pairwise_dims(r, a)
503504
mul!(r, a', a)
504505
sa2 = sumsq_percol(a)
@@ -531,7 +532,8 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
531532
end
532533

533534
# Euclidean
534-
function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::AbstractMatrix)
535+
function _pairwise!(r::AbstractMatrix, dist::Euclidean,
536+
a::AbstractMatrix, b::AbstractMatrix)
535537
m, na, nb = get_pairwise_dims(r, a, b)
536538
mul!(r, a', b)
537539
sa2 = sumsq_percol(a)
@@ -558,7 +560,7 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::Abs
558560
r
559561
end
560562

561-
function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
563+
function _pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
562564
m, n = get_pairwise_dims(r, a)
563565
mul!(r, a', a)
564566
sa2 = sumsq_percol(a)
@@ -586,7 +588,8 @@ end
586588

587589
# CosineDist
588590

589-
function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix, b::AbstractMatrix)
591+
function _pairwise!(r::AbstractMatrix, dist::CosineDist,
592+
a::AbstractMatrix, b::AbstractMatrix)
590593
m, na, nb = get_pairwise_dims(r, a, b)
591594
mul!(r, a', b)
592595
ra = sqrt!(sumsq_percol(a))
@@ -598,7 +601,7 @@ function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix, b::Ab
598601
end
599602
r
600603
end
601-
function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
604+
function _pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
602605
m, n = get_pairwise_dims(r, a)
603606
mul!(r, a', a)
604607
ra = sqrt!(sumsq_percol(a))
@@ -623,9 +626,10 @@ end
623626
function colwise!(r::AbstractVector, dist::CorrDist, a::AbstractVector, b::AbstractMatrix)
624627
colwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
625628
end
626-
function pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix, b::AbstractMatrix)
627-
pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
629+
function _pairwise!(r::AbstractMatrix, dist::CorrDist,
630+
a::AbstractMatrix, b::AbstractMatrix)
631+
_pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
628632
end
629-
function pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix)
630-
pairwise!(r, CosineDist(), _centralize_colwise(a))
633+
function _pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix)
634+
_pairwise!(r, CosineDist(), _centralize_colwise(a))
631635
end

src/wmetrics.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ whamming(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(Weight
117117
###########################################################
118118

119119
# SqEuclidean
120-
function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatrix, b::AbstractMatrix)
120+
function _pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean,
121+
a::AbstractMatrix, b::AbstractMatrix)
121122
w = dist.weights
122123
m, na, nb = get_pairwise_dims(length(w), r, a, b)
123124

@@ -131,7 +132,8 @@ function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatr
131132
end
132133
r
133134
end
134-
function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatrix)
135+
function _pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean,
136+
a::AbstractMatrix)
135137
w = dist.weights
136138
m, n = get_pairwise_dims(length(w), r, a)
137139

@@ -157,9 +159,10 @@ end
157159
function colwise!(r::AbstractArray, dist::WeightedEuclidean, a::AbstractVector, b::AbstractMatrix)
158160
sqrt!(colwise!(r, WeightedSqEuclidean(dist.weights), a, b))
159161
end
160-
function pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix, b::AbstractMatrix)
161-
sqrt!(pairwise!(r, WeightedSqEuclidean(dist.weights), a, b))
162+
function _pairwise!(r::AbstractMatrix, dist::WeightedEuclidean,
163+
a::AbstractMatrix, b::AbstractMatrix)
164+
sqrt!(_pairwise!(r, WeightedSqEuclidean(dist.weights), a, b))
162165
end
163-
function pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix)
164-
sqrt!(pairwise!(r, WeightedSqEuclidean(dist.weights), a))
166+
function _pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix)
167+
sqrt!(_pairwise!(r, WeightedSqEuclidean(dist.weights), a))
165168
end

test/test_dists.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,13 @@ function test_pairwise(dist, x, y, T)
441441
for j = 1:nx, i = 1:nx
442442
rxx[i, j] = evaluate(dist, x[:, i], x[:, j])
443443
end
444-
# ≈ and all( .≈ ) seem to behave slightly differently for F64
445-
# And, as earlier, we have small rounding errors in accumulations
446-
@test all(pairwise(dist, x, y) .+ one(T) .≈ rxy .+ one(T))
447-
@test all(pairwise(dist, x) .+ one(T) .≈ rxx .+ one(T))
444+
# As earlier, we have small rounding errors in accumulations
445+
@test pairwise(dist, x, y) ≈ rxy
446+
@test pairwise(dist, x) ≈ rxx
447+
@test pairwise(dist, x, y, dims=2) ≈ rxy
448+
@test pairwise(dist, x, dims=2) ≈ rxx
449+
@test pairwise(dist, permutedims(x), permutedims(y), dims=1) ≈ rxy
450+
@test pairwise(dist, permutedims(x), dims=1) ≈ rxx
448451
end
449452
end
450453

0 commit comments

Comments
 (0)