diff --git a/Project.toml b/Project.toml index df3bee8..a254908 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Distances" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.7" +version = "0.10.8" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/README.md b/README.md index cf4472c..7103ec2 100644 --- a/README.md +++ b/README.md @@ -132,14 +132,29 @@ If the vector/matrix to store the results are pre-allocated, you may use the storage (without creating a new array) using the following syntax (`i` being either `1` or `2`): +```julia +colwise!(dist, r, X, Y) +pairwise!(dist, R, X, Y, dims=i) +pairwise!(dist, R, X, dims=i) +``` + +Please pay attention to the difference, the functions for inplace computation are +`colwise!` and `pairwise!` (instead of `colwise` and `pairwise`). + +#### Deprecated alternative syntax + +The syntax + ```julia colwise!(r, dist, X, Y) pairwise!(R, dist, X, Y, dims=i) pairwise!(R, dist, X, dims=i) ``` -Please pay attention to the difference, the functions for inplace computation are -`colwise!` and `pairwise!` (instead of `colwise` and `pairwise`). +with the first two arguments (metric and results) interchanged is supported as well. +However, its use is discouraged since +[it is deprecated](https://github.com/JuliaStats/Distances.jl/pull/239) and will be +removed in a future release. ## Distance type hierarchy diff --git a/src/Distances.jl b/src/Distances.jl index be04b78..51068a5 100644 --- a/src/Distances.jl +++ b/src/Distances.jl @@ -118,4 +118,6 @@ include("mahalanobis.jl") include("bhattacharyya.jl") include("bregman.jl") +include("deprecated.jl") + end # module end diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 0000000..504b27c --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,69 @@ +Base.@deprecate pairwise!(r::AbstractMatrix, dist::PreMetric, a) pairwise!(dist, r, a) +Base.@deprecate pairwise!(r::AbstractMatrix, dist::PreMetric, a, b) pairwise!(dist, r, a, b) + +Base.@deprecate pairwise!( + r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix; + dims::Union{Nothing,Integer}=nothing +) pairwise!(dist, r, a; dims=dims) +Base.@deprecate pairwise!( + r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix, b::AbstractMatrix; + dims::Union{Nothing,Integer}=nothing +) pairwise!(dist, r, a, b; dims=dims) + +Base.@deprecate colwise!(r::AbstractArray, dist::PreMetric, a, b) colwise!(dist, r, a, b) + +# docstrings for deprecated methods +@doc """ + pairwise!(r::AbstractMatrix, dist::PreMetric, a) + +Same as `pairwise!(dist, r, a)`. + +!!! warning + Since this alternative syntax is deprecated and will be removed in a future release of + Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a)` instead. +""" pairwise!(r::AbstractMatrix, dist::PreMetric, a) +@doc """ + pairwise!(r::AbstractMatrix, dist::PreMetric, a, b) + +Same as `pairwise!(dist, r, a, b)`. + +!!! warning + Since this alternative syntax is deprecated and will be removed in a future release of + Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a, b)` instead. +""" pairwise!(r::AbstractMatrix, dist::PreMetric, a, b) + +@doc """ + pairwise!(r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix; dims) + +Same as `pairwise!(dist, r, a; dims)`. + + !!! warning + Since this alternative syntax is deprecated and will be removed in a future release of + Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a; dims)` instead. +""" pairwise!( + r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix; + dims::Union{Nothing,Integer} +) +@doc """ + pairwise!(r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix, b::AbstractMatrix; dims) + +Same as `pairwise!(dist, r, a, b; dims)`. + +!!! warning + Since this alternative syntax is deprecated and will be removed in a future release of + Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a, b; dims)` + instead. +""" pairwise!( + r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix, b::AbstractMatrix; + dims::Union{Nothing,Integer} +) + +@doc """ + colwise!(r::AbstractArray, dist::PreMetric, a, b) + +Same as `colwise!(dist, r, a, b)`. + +!!! warning + Since this alternative syntax is deprecated and will be removed in a future release of + Distances.jl, its use is discouraged. Please call `colwise!(dist, r, a, b)` instead. +""" colwise!(r::AbstractArray, dist::PreMetric, a, b) diff --git a/src/generic.jl b/src/generic.jl index 1903131..090bddb 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -46,7 +46,7 @@ __eltype(::Base.EltypeUnknown, a) = _eltype(typeof(first(a))) # Generic column-wise evaluation """ - colwise!(r::AbstractArray, metric::PreMetric, a, b) + colwise!(metric::PreMetric, r::AbstractArray, a, b) Compute distances between corresponding elements of the iterable collections `a` and `b` according to distance `metric`, and store the result in `r`. @@ -54,7 +54,7 @@ Compute distances between corresponding elements of the iterable collections `a` and `b` must have the same number of elements, `r` must be an array of length `length(a) == length(b)`. """ -function colwise!(r::AbstractArray, metric::PreMetric, a, b) +function colwise!(metric::PreMetric, r::AbstractArray, a, b) require_one_based_indexing(r) n = length(a) length(b) == n || throw(DimensionMismatch("iterators have different lengths")) @@ -65,7 +65,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a, b) r end -function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::AbstractMatrix) +function colwise!(metric::PreMetric, r::AbstractArray, a::AbstractVector, b::AbstractMatrix) require_one_based_indexing(r) n = size(b, 2) length(r) == n || throw(DimensionMismatch("Incorrect size of r.")) @@ -75,7 +75,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::Abs r end -function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractVector) +function colwise!(metric::PreMetric, r::AbstractArray, a::AbstractMatrix, b::AbstractVector) require_one_based_indexing(r) n = size(a, 2) length(r) == n || throw(DimensionMismatch("Incorrect size of r.")) @@ -86,11 +86,11 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs end """ - colwise!(r::AbstractArray, metric::PreMetric, + colwise!(metric::PreMetric, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix) - colwise!(r::AbstractArray, metric::PreMetric, + colwise!(metric::PreMetric, r::AbstractArray, a::AbstractVector, b::AbstractMatrix) - colwise!(r::AbstractArray, metric::PreMetric, + colwise!(metric::PreMetric, r::AbstractArray, a::AbstractMatrix, b::AbstractVector) Compute distances between each corresponding columns of `a` and `b` according @@ -105,7 +105,7 @@ vector. `r` must be an array of length `maximum(size(a, 2), size(b, 2))`. If both `a` and `b` are vectors, the generic, iterator-based method of `colwise` applies. """ -function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix) +function colwise!(metric::PreMetric, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix) require_one_based_indexing(r, a, b) n = get_common_ncols(a, b) length(r) == n || throw(DimensionMismatch("Incorrect size of r.")) @@ -126,7 +126,7 @@ Compute distances between corresponding elements of the iterable collections function colwise(metric::PreMetric, a, b) n = get_common_length(a, b) r = Vector{result_type(metric, a, b)}(undef, n) - colwise!(r, metric, a, b) + colwise!(metric, r, a, b) end """ @@ -148,25 +148,25 @@ vector. function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix) n = get_common_ncols(a, b) r = Vector{result_type(metric, a, b)}(undef, n) - colwise!(r, metric, a, b) + colwise!(metric, r, a, b) end function colwise(metric::PreMetric, a::AbstractVector, b::AbstractMatrix) n = size(b, 2) r = Vector{result_type(metric, a, b)}(undef, n) - colwise!(r, metric, a, b) + colwise!(metric, r, a, b) end function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractVector) n = size(a, 2) r = Vector{result_type(metric, a, b)}(undef, n) - colwise!(r, metric, a, b) + colwise!(metric, r, a, b) end # Generic pairwise evaluation -function _pairwise!(r::AbstractMatrix, metric::PreMetric, a, b=a) +function _pairwise!(metric::PreMetric, r::AbstractMatrix, a, b=a) require_one_based_indexing(r) na = length(a) nb = length(b) @@ -177,7 +177,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric, a, b=a) r end -function _pairwise!(r::AbstractMatrix, metric::PreMetric, +function _pairwise!(metric::PreMetric, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix=a) require_one_based_indexing(r, a, b) na = size(a, 2) @@ -192,7 +192,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric, r end -function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a) +function _pairwise!(metric::SemiMetric, r::AbstractMatrix, a) require_one_based_indexing(r) n = length(a) size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r.")) @@ -208,7 +208,7 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a) r end -function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix) +function _pairwise!(metric::SemiMetric, r::AbstractMatrix, a::AbstractMatrix) require_one_based_indexing(r) n = size(a, 2) size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r.")) @@ -237,7 +237,7 @@ function deprecated_dims(dims::Union{Nothing,Integer}) end """ - pairwise!(r::AbstractMatrix, metric::PreMetric, + pairwise!(metric::PreMetric, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix=a; dims) Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`) @@ -247,7 +247,7 @@ If a single matrix `a` is provided, compute distances between its rows or column `a` and `b` must have the same numbers of columns if `dims=1`, or of rows if `dims=2`. `r` must be a matrix with size `size(a, dims) × size(b, dims)`. """ -function pairwise!(r::AbstractMatrix, metric::PreMetric, +function pairwise!(metric::PreMetric, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix; dims::Union{Nothing,Integer}=nothing) dims = deprecated_dims(dims) @@ -266,13 +266,13 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb))).")) if dims == 1 - _pairwise!(r, metric, permutedims(a), permutedims(b)) + _pairwise!(metric, r, permutedims(a), permutedims(b)) else - _pairwise!(r, metric, a, b) + _pairwise!(metric, r, a, b) end end -function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix; +function pairwise!(metric::PreMetric, r::AbstractMatrix, a::AbstractMatrix; dims::Union{Nothing,Integer}=nothing) dims = deprecated_dims(dims) dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) @@ -284,14 +284,14 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix; size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n))).")) if dims == 1 - _pairwise!(r, metric, permutedims(a)) + _pairwise!(metric, r, permutedims(a)) else - _pairwise!(r, metric, a) + _pairwise!(metric, r, a) end end """ - pairwise!(r::AbstractMatrix, metric::PreMetric, a, b=a) + pairwise!(metric::PreMetric, r::AbstractMatrix, a, b=a) Compute distances between each element of collection `a` and each element of collection `b` according to distance `metric`, and store the result in `r`. @@ -299,8 +299,8 @@ If a single iterable `a` is provided, compute distances between its elements. `r` must be a matrix with size `length(a) × length(b)`. """ -pairwise!(r::AbstractMatrix, metric::PreMetric, a, b) = _pairwise!(r, metric, a, b) -pairwise!(r::AbstractMatrix, metric::PreMetric, a) = _pairwise!(r, metric, a) +pairwise!(metric::PreMetric, r::AbstractMatrix, a, b) = _pairwise!(metric, r, a, b) +pairwise!(metric::PreMetric, r::AbstractMatrix, a) = _pairwise!(metric, r, a) """ pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix=a; dims) @@ -318,7 +318,7 @@ function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix; m = size(a, dims) n = size(b, dims) r = Matrix{result_type(metric, a, b)}(undef, m, n) - pairwise!(r, metric, a, b, dims=dims) + pairwise!(metric, r, a, b, dims=dims) end function pairwise(metric::PreMetric, a::AbstractMatrix; @@ -327,7 +327,7 @@ function pairwise(metric::PreMetric, a::AbstractMatrix; dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) n = size(a, dims) r = Matrix{result_type(metric, a, a)}(undef, n, n) - pairwise!(r, metric, a, dims=dims) + pairwise!(metric, r, a, dims=dims) end """ @@ -341,11 +341,11 @@ function pairwise(metric::PreMetric, a, b) m = length(a) n = length(b) r = Matrix{result_type(metric, a, b)}(undef, m, n) - _pairwise!(r, metric, a, b) + _pairwise!(metric, r, a, b) end function pairwise(metric::PreMetric, a) n = length(a) r = Matrix{result_type(metric, a, a)}(undef, n, n) - _pairwise!(r, metric, a) + _pairwise!(metric, r, a) end diff --git a/src/mahalanobis.jl b/src/mahalanobis.jl index dcfd128..ccfb57d 100644 --- a/src/mahalanobis.jl +++ b/src/mahalanobis.jl @@ -95,34 +95,34 @@ end sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b) mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b) -function _colwise!(r, dist, a, b) +function _colwise!(dist, r, a, b) Q = dist.qmat get_colwise_dims(size(Q, 1), r, a, b) z = a .- b dot_percol!(r, Q * z, z) end -function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix) - _colwise!(r, dist, a, b) +function colwise!(dist::SqMahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix) + _colwise!(dist, r, a, b) end -function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractVector, b::AbstractMatrix) - _colwise!(r, dist, a, b) +function colwise!(dist::SqMahalanobis, r::AbstractArray, a::AbstractVector, b::AbstractMatrix) + _colwise!(dist, r, a, b) end -function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractVector) - _colwise!(r, dist, a, b) +function colwise!(dist::SqMahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractVector) + _colwise!(dist, r, a, b) end -function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractMatrix) - sqrt!(_colwise!(r, dist, a, b)) +function colwise!(dist::Mahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix) + sqrt!(_colwise!(dist, r, a, b)) end -function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractVector, b::AbstractMatrix) - sqrt!(_colwise!(r, dist, a, b)) +function colwise!(dist::Mahalanobis, r::AbstractArray, a::AbstractVector, b::AbstractMatrix) + sqrt!(_colwise!(dist, r, a, b)) end -function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractVector) - sqrt!(_colwise!(r, dist, a, b)) +function colwise!(dist::Mahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractVector) + sqrt!(_colwise!(dist, r, a, b)) end -function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a::AbstractMatrix, b::AbstractMatrix) +function _pairwise!(dist::Union{SqMahalanobis,Mahalanobis}, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) Q = dist.qmat m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b) @@ -140,7 +140,7 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a r end -function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a::AbstractMatrix) +function _pairwise!(dist::Union{SqMahalanobis,Mahalanobis}, r::AbstractMatrix, a::AbstractMatrix) Q = dist.qmat m, n = get_pairwise_dims(size(Q, 1), r, a) diff --git a/src/metrics.jl b/src/metrics.jl index ec4de63..9ee5ff7 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -659,7 +659,7 @@ const nrmsd = NormRMSDeviation() ########################################################### # SqEuclidean/Euclidean -function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, +function _pairwise!(dist::Union{SqEuclidean,Euclidean}, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) require_one_based_indexing(r, a, b) m, na, nb = get_pairwise_dims(r, a, b) @@ -700,7 +700,7 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, r end -function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, a::AbstractMatrix) +function _pairwise!(dist::Union{SqEuclidean,Euclidean}, r::AbstractMatrix, a::AbstractMatrix) require_one_based_indexing(r, a) m, n = get_pairwise_dims(r, a) inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(a)))) === eltype(r) @@ -737,7 +737,7 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, a::Ab end # Weighted SqEuclidean/Euclidean -function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedEuclidean}, +function _pairwise!(dist::Union{WeightedSqEuclidean,WeightedEuclidean}, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) require_one_based_indexing(r, a, b) w = dist.weights @@ -756,7 +756,7 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE end r end -function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedEuclidean}, +function _pairwise!(dist::Union{WeightedSqEuclidean,WeightedEuclidean}, r::AbstractMatrix, a::AbstractMatrix) require_one_based_indexing(r, a) w = dist.weights @@ -781,8 +781,8 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE end # MeanSqDeviation, RMSDeviation, NormRMSDeviation -function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix, b::AbstractMatrix) - _pairwise!(r, SqEuclidean(), a, b) +function _pairwise!(dist::MeanSqDeviation, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) + _pairwise!(SqEuclidean(), r, a, b) # TODO: Replace by rdiv!(r, size(a, 1)) once julia compat ≥v1.2 s = size(a, 1) @simd for I in eachindex(r) @@ -790,10 +790,10 @@ function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix, end return r end -_pairwise!(r::AbstractMatrix, dist::RMSDeviation, a::AbstractMatrix, b::AbstractMatrix) = - sqrt!(_pairwise!(r, MeanSqDeviation(), a, b)) -function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix, b::AbstractMatrix) - _pairwise!(r, RMSDeviation(), a, b) +_pairwise!(dist::RMSDeviation, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) = + sqrt!(_pairwise!(MeanSqDeviation(), r, a, b)) +function _pairwise!(dist::NormRMSDeviation, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) + _pairwise!(RMSDeviation(), r, a, b) @views for (i, j) in zip(axes(r, 1), axes(a, 2)) amin, amax = extrema(a[:,j]) r[i,:] ./= amax - amin @@ -801,8 +801,8 @@ function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix return r end -function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix) - _pairwise!(r, SqEuclidean(), a) +function _pairwise!(dist::MeanSqDeviation, r::AbstractMatrix, a::AbstractMatrix) + _pairwise!(SqEuclidean(), r, a) # TODO: Replace by rdiv!(r, size(a, 1)) once julia compat ≥v1.2 s = size(a, 1) @simd for I in eachindex(r) @@ -810,10 +810,10 @@ function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix) end return r end -_pairwise!(r::AbstractMatrix, dist::RMSDeviation, a::AbstractMatrix) = - sqrt!(_pairwise!(r, MeanSqDeviation(), a)) -function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix) - _pairwise!(r, RMSDeviation(), a) +_pairwise!(dist::RMSDeviation, r::AbstractMatrix, a::AbstractMatrix) = + sqrt!(_pairwise!(MeanSqDeviation(), r, a)) +function _pairwise!(dist::NormRMSDeviation, r::AbstractMatrix, a::AbstractMatrix) + _pairwise!(RMSDeviation(), r, a) @views for (i, j) in zip(axes(r, 1), axes(a, 2)) amin, amax = extrema(a[:,j]) r[i,:] ./= amax - amin @@ -823,7 +823,7 @@ end # CosineDist -function _pairwise!(r::AbstractMatrix, ::CosineDist, a::AbstractMatrix, b::AbstractMatrix) +function _pairwise!(::CosineDist, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) require_one_based_indexing(r, a, b) m, na, nb = get_pairwise_dims(r, a, b) inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(b)))) === eltype(r) @@ -837,7 +837,7 @@ function _pairwise!(r::AbstractMatrix, ::CosineDist, a::AbstractMatrix, b::Abstr end r end -function _pairwise!(r::AbstractMatrix, ::CosineDist, a::AbstractMatrix) +function _pairwise!(::CosineDist, r::AbstractMatrix, a::AbstractMatrix) require_one_based_indexing(r, a) m, n = get_pairwise_dims(r, a) inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(a)))) === eltype(r) @@ -861,7 +861,7 @@ end # 2. pre-calculated `_centralize_colwise` avoids four times of redundant computations # of `_centralize` -- ~4x speed up _centralize_colwise(x::AbstractMatrix) = x .- mean(x, dims=1) -_pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix, b::AbstractMatrix) = - _pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b)) -_pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix) = - _pairwise!(r, CosineDist(), _centralize_colwise(a)) +_pairwise!(::CorrDist, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix) = + _pairwise!(CosineDist(), r, _centralize_colwise(a), _centralize_colwise(b)) +_pairwise!(::CorrDist, r::AbstractMatrix, a::AbstractMatrix) = + _pairwise!(CosineDist(), r, _centralize_colwise(a)) diff --git a/test/test_dists.jl b/test/test_dists.jl index dfe66d2..47941c3 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -388,10 +388,10 @@ end # testset @test_throws DimensionMismatch sqmahalanobis(q, q, Q) mat23 = [0.3 0.2 0.0; 0.1 0.0 0.4] mat22 = [0.3 0.2; 0.1 0.4] - @test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat23) - @test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, q) - @test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat22) - @test_throws DimensionMismatch colwise!(mat23, Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), mat23, mat22) + @test_throws DimensionMismatch colwise!(Euclidean(), mat23, mat23, mat23) + @test_throws DimensionMismatch colwise!(Euclidean(), mat23, mat23, q) + @test_throws DimensionMismatch colwise!(Euclidean(), mat23, mat23, mat22) + @test_throws DimensionMismatch colwise!(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), mat23, mat23, mat22) @test_throws DimensionMismatch Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x)([1, 2, 3], [1, 2]) @test_throws DimensionMismatch Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2])([1, 2, 3], [1, 2, 3]) sv1 = sprand(10, .2) @@ -559,6 +559,7 @@ function test_colwise(dist, x, y, T) r2 = zeros(T, n) r3 = zeros(T, n) r4 = zeros(T, 1, n) + r5 = zeros(T, 1, n) for j = 1:n r1[j] = dist(x[:, j], y[:, j]) r2[j] = dist(x[:, 1], y[:, j]) @@ -567,7 +568,10 @@ function test_colwise(dist, x, y, T) # ≈ and all( .≈ ) seem to behave slightly differently for F64 @test all(colwise(dist, x, y) .≈ r1) @test all(colwise(dist, (x[:,i] for i in axes(x, 2)), (y[:,i] for i in axes(y, 2))) .≈ r1) - colwise!(r4, dist, x, y) + + @test colwise!(dist, r4, x, y) ≈ @test_deprecated(colwise!(r5, dist, x, y)) + @test r4 ≈ r5 + @test all(r4[i] ≈ r1[i] for i in 1:n) @test all(colwise(dist, x[:, 1], y) .≈ r2) @test all(colwise(dist, x, y[:, 1]) .≈ r3) @@ -668,13 +672,56 @@ function test_pairwise(dist, x, y, T) @test pairwise(dist, x, dims=2) ≈ rxx @test pairwise(dist, permutedims(x), permutedims(y), dims=1) ≈ rxy @test pairwise(dist, permutedims(x), dims=1) ≈ rxx + + # In-place computations + rxy2 = zeros(T, nx, ny) + @test @test_deprecated(pairwise!(rxy2, dist, x, y; dims=2)) ≈ rxy + @test rxy2 ≈ rxy + fill!(rxy2, zero(T)) + @test pairwise!(dist, rxy2, x, y; dims=2) ≈ rxy + @test rxy2 ≈ rxy + + rxx2 = zeros(T, nx, nx) + @test @test_deprecated(pairwise!(rxx2, dist, x; dims=2)) ≈ rxx + @test rxx2 ≈ rxx + fill!(rxx2, zero(T)) + @test pairwise!(dist, rxx2, x; dims=2) ≈ rxx + @test rxx2 ≈ rxx + + fill!(rxy2, zero(T)) + @test @test_deprecated(pairwise!(rxy2, dist, permutedims(x), permutedims(y); dims=1)) ≈ rxy + @test rxy2 ≈ rxy + fill!(rxy2, zero(T)) + @test pairwise!(dist, rxy2, permutedims(x), permutedims(y); dims=1) ≈ rxy + @test rxy2 ≈ rxy + + fill!(rxx2, zero(T)) + @test @test_deprecated(pairwise!(rxx2, dist, permutedims(x); dims=1)) ≈ rxx + @test rxx2 ≈ rxx + fill!(rxx2, zero(T)) + @test pairwise!(dist, rxx2, permutedims(x); dims=1) ≈ rxx + @test rxx2 ≈ rxx + + # General arguments (iterators and vectors of vectors) vecx = (x[:, i] for i in 1:nx) vecy = (y[:, i] for i in 1:ny) for (vecx, vecy) in ((vecx, vecy), (collect(vecx), collect(vecy))) @test pairwise(dist, vecx, vecy) ≈ rxy @test pairwise(dist, vecx) ≈ rxx - @test pairwise!(similar(rxy), dist, vecx, vecy) ≈ rxy - @test pairwise!(similar(rxx), dist, vecx) ≈ rxx + + fill!(rxy2, zero(T)) + @test @test_deprecated(pairwise!(rxy2, dist, vecx, vecy)) ≈ rxy + @test rxy2 ≈ rxy + fill!(rxy2, zero(T)) + @test pairwise!(dist, rxy2, vecx, vecy) ≈ rxy + @test rxy2 ≈ rxy + + fill!(rxx2, zero(T)) + @test @test_deprecated(pairwise!(rxx2, dist, vecx)) ≈ rxx + @test rxx2 ≈ rxx + fill!(rxx2, zero(T)) + @test pairwise!(dist, rxx2, vecx) ≈ rxx + @test rxx2 ≈ rxx end end end @@ -798,6 +845,37 @@ function test_scalar_pairwise(dist, x, y, T) @test pairwise(dist, permutedims(x), permutedims(y), dims=2) ≈ rxy @test pairwise(dist, permutedims(x), dims=2) ≈ rxx @test_throws DimensionMismatch pairwise(dist, permutedims(x), permutedims(y), dims=1) + + # In-place computations + rxy2 = similar(rxy) + fill!(rxy2, zero(eltype(rxy2))) + @test @test_deprecated(pairwise!(rxy2, dist, x, y)) ≈ rxy + @test rxy2 ≈ rxy + fill!(rxy2, zero(eltype(rxy2))) + @test pairwise!(dist, rxy2, x, y) ≈ rxy + @test rxy2 ≈ rxy + + rxx2 = similar(rxx) + fill!(rxx2, zero(eltype(rxx2))) + @test @test_deprecated(pairwise!(rxx2, dist, x)) ≈ rxx + @test rxx2 ≈ rxx + fill!(rxx2, zero(eltype(rxx2))) + @test pairwise!(dist, rxx2, x) ≈ rxx + @test rxx2 ≈ rxx + + fill!(rxy2, zero(eltype(rxy2))) + @test @test_deprecated(pairwise!(rxy2, dist, permutedims(x), permutedims(y); dims=2)) ≈ rxy + @test rxy2 ≈ rxy + fill!(rxy2, zero(eltype(rxy2))) + @test pairwise!(dist, rxy2, permutedims(x), permutedims(y); dims=2) ≈ rxy + @test rxy2 ≈ rxy + + fill!(rxx2, zero(eltype(rxx2))) + @test @test_deprecated(pairwise!(rxx2, dist, permutedims(x); dims=2)) ≈ rxx + @test rxx2 ≈ rxx + fill!(rxx2, zero(eltype(rxx2))) + @test pairwise!(dist, rxx2, permutedims(x); dims=2) ≈ rxx + @test rxx2 ≈ rxx end end @@ -972,13 +1050,13 @@ end a = rand(2, 41) b = rand(2, 41) z = zeros(41) - colwise!(z, d, a, b) + colwise!(d, z, a, b) # This fails when bounds checking is enforced bounds = Base.JLOptions().check_bounds if bounds == 0 - @test (@allocated colwise!(z, d, a, b)) == 0 + @test (@allocated colwise!(d, z, a, b)) == 0 else - @test_broken (@allocated colwise!(z, d, a, b)) == 0 + @test_broken (@allocated colwise!(d, z, a, b)) == 0 end end =#