Skip to content

Commit 1d0a66a

Browse files
committed
fix custom distance kmeans test
1 parent 26d3acb commit 1d0a66a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

test/kmeans.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,24 @@ using Distances
44
using Random
55
using LinearAlgebra
66

7-
import Distances.pairwise!
8-
97
# custom distance metric
108
struct MySqEuclidean <: SemiMetric end
119

1210
# redefinition of Distances.pairwise! for MySqEuclidean type
13-
function pairwise!(r::AbstractMatrix, dist::MySqEuclidean,
14-
a::AbstractMatrix, b::AbstractMatrix; dims::Integer=2)
11+
function Distances.pairwise!(r::AbstractMatrix, dist::MySqEuclidean,
12+
a::AbstractMatrix, b::AbstractMatrix; dims::Integer=2)
1513
dims == 2 || throw(ArgumentError("only dims=2 supported for MySqEuclidean distance"))
1614
mul!(r, transpose(a), b)
1715
sa2 = sum(abs2, a, dims=1)
1816
sb2 = sum(abs2, b, dims=1)
1917
@inbounds r .= sa2' .+ sb2 .- 2r
2018
end
2119

20+
Distances.result_type(::MySqEuclidean, ::Type{T}, ::Type{T}) where T <: Number = T
21+
22+
(dist::MySqEuclidean)(a::AbstractMatrix, b::AbstractMatrix) =
23+
pairwise!(similar(a, size(a, 2), size(b, 2)))
24+
2225
@testset "kmeans() (k-means)" begin
2326

2427
@testset "Argument checks" begin

0 commit comments

Comments
 (0)