Skip to content

Commit d18762c

Browse files
authored
Add pairwise (#627)
This generic method takes iterators of vectors and supports skipping missing values. It is a more general version of `pairwise` in Distances.jl. Since methods are compatible, both packages can override a common empty function defined in StatsAPI.
1 parent 45d65ec commit d18762c

File tree

6 files changed

+583
-0
lines changed

6 files changed

+583
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
SortingAlgorithms = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
16+
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
1617

1718
[compat]
1819
DataAPI = "1"
1920
DataStructures = "0.10, 0.11, 0.12, 0.13, 0.14, 0.17, 0.18"
2021
Missings = "0.3, 0.4, 1.0"
2122
SortingAlgorithms = "0.3, 1.0"
23+
StatsAPI = "1"
2224
julia = "1"
2325

2426
[extras]

docs/src/misc.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ levelsmap
77
indexmap
88
indicatormat
99
StatsBase.midpoints
10+
pairwise
11+
pairwise!
1012
```

src/StatsBase.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import LinearAlgebra: BlasReal, BlasFloat
1919
import Statistics: mean, mean!, var, varm, varm!, std, stdm, cov, covm,
2020
cor, corm, cov2cor!, unscaled_covzm, quantile, sqrt!,
2121
median, middle
22+
import StatsAPI: pairwise, pairwise!
2223

2324
## tackle compatibility issues
2425

@@ -157,6 +158,8 @@ export
157158
indexmap, # construct a map from element to index
158159
levelsmap, # construct a map from n unique elements to [1, ..., n]
159160
indicatormat, # construct indicator matrix
161+
pairwise, # pairwise application of functions
162+
pairwise!, # pairwise! application of functions
160163

161164
# statistical models
162165
CoefTable,
@@ -228,6 +231,7 @@ include("signalcorr.jl")
228231
include("partialcor.jl")
229232
include("empirical.jl")
230233
include("hist.jl")
234+
include("pairwise.jl")
231235
include("misc.jl")
232236

233237
include("sampling.jl")

src/pairwise.jl

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# Fill `dest` with `f(xi, yj)` for every pair of entries of `x` and `y`,
# passing entries to `f` untouched (no handling of missing values).
#
# Rows of `dest` follow the iteration order of `x` and columns that of `y`.
# When `symmetric` is `true`, `f` is only evaluated for cells on or above the
# diagonal; cells below the diagonal are then mirrored from those above it.
# Returns `dest`.
function _pairwise!(::Val{:none}, f, dest::AbstractMatrix, x, y, symmetric::Bool)
    @inbounds for (row, xrow) in enumerate(x)
        for (col, ycol) in enumerate(y)
            if symmetric && row > col
                continue
            end

            # For performance, the diagonal of a correlation matrix is special-cased:
            # it applies only when the very same object sits on both sides of the pair
            diagcor = f === cor && eltype(dest) !== Union{} && row == col && xrow === ycol
            if diagcor
                # TODO: float() will not be needed after JuliaLang/Statistics.jl#61
                dest[row, col] = float(cor(xrow))
            else
                dest[row, col] = f(xrow, ycol)
            end
        end
    end

    if symmetric
        nrows, ncols = size(dest)
        # Mirror the computed cells (row < col) into the cells below the diagonal
        @inbounds for col in 1:ncols
            for row in (col + 1):nrows
                dest[row, col] = dest[col, row]
            end
        end
    end

    return dest
end
21+
22+
# Validate the inputs of the `skipmissing=:pairwise` and `:listwise` modes.
#
# Every entry of `x` and of `y` must be an `AbstractVector`, and all of these
# vectors must share the same indices (as reported by `keys`), since a single
# missingness mask is applied to vectors coming from both collections.
#
# `skipmissing` is only used to build the error message.
# Throws an `ArgumentError` when a check fails; returns `nothing` otherwise.
function check_vectors(x, y, skipmissing::Symbol)
    if !(all(xi -> xi isa AbstractVector, x) && all(yi -> yi isa AbstractVector, y))
        throw(ArgumentError("All entries in x and y must be vectors " *
                            "when skipmissing=:$skipmissing"))
    end

    # Compare every vector against the first one encountered, whichever
    # collection it comes from. This also covers the case where `x` or `y`
    # holds a single vector, and it avoids indexing into `x`/`y` by position.
    refinds = nothing
    for collection in (x, y), v in collection
        inds = keys(v)
        if refinds === nothing
            refinds = inds
        elseif inds != refinds
            throw(ArgumentError("All input vectors must have the same indices"))
        end
    end

    return nothing
end
48+
49+
# Fill `dest` with `f` applied to every pair of vectors from `x` and `y`,
# dropping, for each pair separately, the positions where either of the two
# vectors holds a `missing` value.
#
# Inputs are validated with `check_vectors` first. When `symmetric` is `true`,
# `f` is only evaluated for cells on or above the diagonal; cells below the
# diagonal are then mirrored from those above it. Returns `dest`.
function _pairwise!(::Val{:pairwise}, f, dest::AbstractMatrix, x, y, symmetric::Bool)
    check_vectors(x, y, :pairwise)

    @inbounds for (col, ycol) in enumerate(y)
        # Mask of non-missing positions in the column vector, shared by all rows
        ypresent = .!ismissing.(ycol)
        for (row, xrow) in enumerate(x)
            if symmetric && row > col
                continue
            end

            if xrow !== ycol
                # Keep only the positions that are non-missing in both vectors
                present = .!ismissing.(xrow) .& ypresent
                xkept = view(xrow, present)
                ykept = view(ycol, present)
                dest[row, col] = f(xkept, ykept)
            else
                # Same object on both sides: a single mask is enough
                ykept = view(ycol, ypresent)
                # For performance, diagonal is special-cased
                if f === cor && eltype(dest) !== Union{} && row == col
                    # TODO: float() will not be needed after JuliaLang/Statistics.jl#61
                    dest[row, col] = float(cor(xrow))
                else
                    dest[row, col] = f(ykept, ykept)
                end
            end
        end
    end

    if symmetric
        nrows, ncols = size(dest)
        # Mirror the computed cells (row < col) into the cells below the diagonal
        @inbounds for col in 1:ncols
            for row in (col + 1):nrows
                dest[row, col] = dest[col, row]
            end
        end
    end

    return dest
end
81+
82+
# Fill `dest` with `f` applied to every pair of vectors from `x` and `y`,
# after dropping from all vectors the positions where any vector in `x` or `y`
# holds a `missing` value.
#
# Inputs are validated with `check_vectors` first. The filtered views are then
# handed to the `Val(:none)` method, which does the actual filling (including
# the `symmetric` handling). Returns `dest`.
function _pairwise!(::Val{:listwise}, f, dest::AbstractMatrix, x, y, symmetric::Bool)
    check_vectors(x, y, :listwise)

    # With no entries in `x` there is no cell to fill, and no first vector
    # from which to seed the mask below
    isempty(x) && return dest

    # Positions that are non-missing in every vector seen so far
    nminds = .!ismissing.(first(x))
    @inbounds for xi in Iterators.drop(x, 1)
        nminds .&= .!ismissing.(xi)
    end
    # When `y` is the same object as `x` its vectors are already accounted for
    if x !== y
        @inbounds for yj in y
            nminds .&= .!ismissing.(yj)
        end
    end

    # Computing integer indices once for all vectors is faster
    nminds′ = findall(nminds)
    # TODO: check whether wrapping views in a custom array type which asserts
    # that entries cannot be `missing` (similar to `skipmissing`)
    # could offer better performance
    return _pairwise!(Val(:none), f, dest,
                      [view(xi, nminds′) for xi in x],
                      [view(yi, nminds′) for yi in y],
                      symmetric)
end
105+
106+
# Validate arguments and dispatch to the `_pairwise!` method implementing the
# requested `skipmissing` mode.
#
# - `x` and `y` are materialized with `collect` unless they already are an
#   `AbstractArray`, `Tuple` or `NamedTuple`.
# - `dest` must have size `(length(x), length(y))` and one-based indices.
#
# Throws `ArgumentError` for an unknown `skipmissing` value or for a `dest`
# with offset axes, and `DimensionMismatch` when `dest` has the wrong size.
# Returns `dest`.
function _pairwise!(f, dest::AbstractMatrix, x, y;
                    symmetric::Bool=false, skipmissing::Symbol=:none)
    if !(skipmissing in (:none, :pairwise, :listwise))
        throw(ArgumentError("skipmissing must be one of :none, :pairwise or :listwise"))
    end

    x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
    # When both arguments are the same object, reuse the materialized `x`:
    # a single-pass iterator would already be exhausted by the first `collect`,
    # and sharing the result keeps `x′ === y′` for the methods called below
    y′ = if y === x
        x′
    elseif y isa Union{AbstractArray, Tuple, NamedTuple}
        y
    else
        collect(y)
    end
    m = length(x′)
    n = length(y′)

    size(dest) != (m, n) &&
        throw(DimensionMismatch("dest has dimensions $(size(dest)) but expected ($m, $n)"))

    # Throw a proper exception object rather than a bare string
    Base.has_offset_axes(dest) &&
        throw(ArgumentError("dest indices must start at 1"))

    return _pairwise!(Val(skipmissing), f, dest, x′, y′, symmetric)
end
124+
125+
# Allocate a matrix of size `length(x) × length(y)`, fill it through
# `_pairwise!`, and return it.
#
# `x` and `y` are materialized with `collect` unless they already are an
# `AbstractArray`, `Tuple` or `NamedTuple`. The element type of the result is
# taken from inference: on `f` applied to the entry types when
# `skipmissing === :none`, and on `f` applied to `disallowmissing` of the
# entries for `:pairwise` and `:listwise`. When that type is not concrete, the
# result is converted to the promoted type of its actual contents.
#
# Throws `ArgumentError` for any other `skipmissing` value.
function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing}
    x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
    # When both arguments are the same object, reuse the materialized `x`:
    # a single-pass iterator would already be exhausted by the first `collect`,
    # and sharing the result keeps `x′ === y′` for the methods called below
    y′ = if y === x
        x′
    elseif y isa Union{AbstractArray, Tuple, NamedTuple}
        y
    else
        collect(y)
    end
    m = length(x′)
    n = length(y′)

    T = Core.Compiler.return_type(f, Tuple{eltype(x′), eltype(y′)})
    Tsm = Core.Compiler.return_type((x, y) -> f(disallowmissing(x), disallowmissing(y)),
                                    Tuple{eltype(x′), eltype(y′)})

    if skipmissing === :none
        dest = Matrix{T}(undef, m, n)
    elseif skipmissing in (:pairwise, :listwise)
        dest = Matrix{Tsm}(undef, m, n)
    else
        throw(ArgumentError("skipmissing must be one of :none, :pairwise or :listwise"))
    end

    # Preserve inferred element type
    isempty(dest) && return dest

    _pairwise!(f, dest, x′, y′, symmetric=symmetric, skipmissing=skipmissing)

    if isconcretetype(eltype(dest))
        return dest
    else
        # Final eltype depends on actual contents (consistent with map and broadcast)
        U = mapreduce(typeof, promote_type, dest)
        # V is inferred (contrary to U), but it only gives an upper bound for U
        V = promote_type(T, Tsm)
        return convert(Matrix{U}, dest)::Matrix{<:V}
    end
end
158+
159+
"""
160+
pairwise!(f, dest::AbstractMatrix, x[, y];
161+
symmetric::Bool=false, skipmissing::Symbol=:none)
162+
163+
Store in matrix `dest` the result of applying `f` to all possible pairs
164+
of entries in iterators `x` and `y`, and return it. Rows correspond to
165+
entries in `x` and columns to entries in `y`, and `dest` must therefore
166+
be of size `length(x) × length(y)`.
167+
If `y` is omitted then `x` is crossed with itself.
168+
169+
As a special case, if `f` is `cor`, diagonal cells for which entries
170+
from `x` and `y` are identical (according to `===`) are set to one even
171+
in the presence of `missing`, `NaN` or `Inf` entries.
172+
173+
# Keyword arguments
174+
- `symmetric::Bool=false`: If `true`, `f` is only called to compute
175+
the upper triangle of the matrix (including the diagonal), and these values
176+
are copied to fill the lower triangle. Only allowed when `y` is omitted.
177+
Defaults to `true` when `f` is `cor` or `cov`.
178+
- `skipmissing::Symbol=:none`: If `:none` (the default), missing values
179+
in inputs are passed to `f` without any modification.
180+
Use `:pairwise` to skip entries with a `missing` value in either
181+
of the two vectors passed to `f` for a given pair of vectors in `x` and `y`.
182+
Use `:listwise` to skip entries with a `missing` value in any of the
183+
vectors in `x` or `y`; note that this might drop a large part of entries.
184+
Only allowed when entries in `x` and `y` are vectors.
185+
186+
# Examples
187+
```jldoctest
188+
julia> using StatsBase, Statistics
189+
190+
julia> dest = zeros(3, 3);
191+
192+
julia> x = [1 3 7
193+
2 5 6
194+
3 8 4
195+
4 6 2];
196+
197+
julia> pairwise!(cor, dest, eachcol(x));
198+
199+
julia> dest
200+
3×3 Matrix{Float64}:
201+
1.0 0.744208 -0.989778
202+
0.744208 1.0 -0.68605
203+
-0.989778 -0.68605 1.0
204+
205+
julia> y = [1 3 missing
206+
2 5 6
207+
3 missing 2
208+
4 6 2];
209+
210+
julia> pairwise!(cor, dest, eachcol(y), skipmissing=:pairwise);
211+
212+
julia> dest
213+
3×3 Matrix{Float64}:
214+
1.0 0.928571 -0.866025
215+
0.928571 1.0 -1.0
216+
-0.866025 -1.0 1.0
217+
```
218+
"""
219+
function pairwise!(f, dest::AbstractMatrix, x, y=x;
                   symmetric::Bool=false, skipmissing::Symbol=:none)
    # Mirroring one triangle into the other is only meaningful when `x` is
    # crossed with itself, i.e. when both arguments are the same object
    (!symmetric || x === y) ||
        throw(ArgumentError("symmetric=true only makes sense passing " *
                            "a single set of variables (x === y)"))

    # Remaining validation and the actual work happen in `_pairwise!`
    return _pairwise!(f, dest, x, y, symmetric=symmetric, skipmissing=skipmissing)
end
228+
229+
"""
230+
pairwise(f, x[, y];
231+
symmetric::Bool=false, skipmissing::Symbol=:none)
232+
233+
Return a matrix holding the result of applying `f` to all possible pairs
234+
of entries in iterators `x` and `y`. Rows correspond to
235+
entries in `x` and columns to entries in `y`. If `y` is omitted then a
236+
square matrix crossing `x` with itself is returned.
237+
238+
As a special case, if `f` is `cor`, diagonal cells for which entries
239+
from `x` and `y` are identical (according to `===`) are set to one even
240+
in the presence of `missing`, `NaN` or `Inf` entries.
241+
242+
# Keyword arguments
243+
- `symmetric::Bool=false`: If `true`, `f` is only called to compute
244+
the upper triangle of the matrix (including the diagonal), and these values
245+
are copied to fill the lower triangle. Only allowed when `y` is omitted.
246+
Defaults to `true` when `f` is `cor` or `cov`.
247+
- `skipmissing::Symbol=:none`: If `:none` (the default), missing values
248+
in inputs are passed to `f` without any modification.
249+
Use `:pairwise` to skip entries with a `missing` value in either
250+
of the two vectors passed to `f` for a given pair of vectors in `x` and `y`.
251+
Use `:listwise` to skip entries with a `missing` value in any of the
252+
vectors in `x` or `y`; note that this might drop a large part of entries.
253+
Only allowed when entries in `x` and `y` are vectors.
254+
255+
# Examples
256+
```jldoctest
257+
julia> using StatsBase, Statistics
258+
259+
julia> x = [1 3 7
260+
2 5 6
261+
3 8 4
262+
4 6 2];
263+
264+
julia> pairwise(cor, eachcol(x))
265+
3×3 Matrix{Float64}:
266+
1.0 0.744208 -0.989778
267+
0.744208 1.0 -0.68605
268+
-0.989778 -0.68605 1.0
269+
270+
julia> y = [1 3 missing
271+
2 5 6
272+
3 missing 2
273+
4 6 2];
274+
275+
julia> pairwise(cor, eachcol(y), skipmissing=:pairwise)
276+
3×3 Matrix{Float64}:
277+
1.0 0.928571 -0.866025
278+
0.928571 1.0 -1.0
279+
-0.866025 -1.0 1.0
280+
```
281+
"""
282+
function pairwise(f, x, y=x; symmetric::Bool=false, skipmissing::Symbol=:none)
    # Mirroring one triangle into the other is only meaningful when `x` is
    # crossed with itself, i.e. when both arguments are the same object
    (!symmetric || x === y) ||
        throw(ArgumentError("symmetric=true only makes sense passing " *
                            "a single set of variables (x === y)"))

    # The `skipmissing` mode is lifted to the type domain for `_pairwise`
    return _pairwise(Val(skipmissing), f, x, y, symmetric)
end
290+
291+
# Covariance of `x` and `y`, using the one-argument `cov(x)` when both
# arguments are the same object, since cov(x) is faster than cov(x, x)
function _cov(x, y)
    if x === y
        return cov(x)
    else
        return cov(x, y)
    end
end
293+
294+
# `cov` is routed through `_cov` in all methods below.
# Two-collection forms: `symmetric` defaults to `false`.
function pairwise!(::typeof(cov), dest::AbstractMatrix, x, y;
                   symmetric::Bool=false, skipmissing::Symbol=:none)
    return pairwise!(_cov, dest, x, y, symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise(::typeof(cov), x, y;
                  symmetric::Bool=false, skipmissing::Symbol=:none)
    return pairwise(_cov, x, y, symmetric=symmetric, skipmissing=skipmissing)
end

# Single-collection forms: `x` is crossed with itself and `symmetric`
# defaults to `true`.
function pairwise!(::typeof(cov), dest::AbstractMatrix, x;
                   symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise!(_cov, dest, x, x, symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise(::typeof(cov), x;
                  symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise(_cov, x, x, symmetric=symmetric, skipmissing=skipmissing)
end
307+
308+
# Single-collection forms for `cor`: `x` is crossed with itself and
# `symmetric` defaults to `true`.
function pairwise!(::typeof(cor), dest::AbstractMatrix, x;
                   symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise!(cor, dest, x, x, symmetric=symmetric, skipmissing=skipmissing)
end

function pairwise(::typeof(cor), x;
                  symmetric::Bool=true, skipmissing::Symbol=:none)
    return pairwise(cor, x, x, symmetric=symmetric, skipmissing=skipmissing)
end

0 commit comments

Comments
 (0)