Skip to content

Commit 7bae312

Browse files
authored
Merge pull request #23 from JuliaAI/adjoint-not-transpose
Use adjoints not transposes
2 parents 1a2c1e7 + 8fd563a commit 7bae312

File tree

4 files changed

+9
-16
lines changed

4 files changed

+9
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJClusteringInterface"
22
uuid = "d354fa79-ed1c-40d4-88ef-b8c7bd1568af"
33
authors = ["Anthony D. Blaom <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"

src/MLJClusteringInterface.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ const PKG = "MLJClusteringInterface"
3535
end
3636

3737
function MMI.fit(model::KMeans, verbosity::Int, X)
38-
# NOTE: using transpose here to get a LinearAlgebra.Transpose object
39-
# which Kmeans can handle.
40-
Xarray = transpose(MMI.matrix(X))
38+
Xarray = MMI.matrix(X)'
4139
result = Cl.kmeans(Xarray, model.k; distance=model.metric, init=model.init)
4240
cluster_labels = MMI.categorical(1:model.k)
4341
fitresult = (result.centers, cluster_labels) # centers (p x k)
@@ -55,7 +53,7 @@ function MMI.transform(model::KMeans, fitresult, X)
5553
# pairwise distance from samples to centers
5654
= pairwise(
5755
model.metric,
58-
transpose(MMI.matrix(X)),
56+
MMI.matrix(X)',
5957
fitresult[1],
6058
dims=2
6159
)
@@ -71,9 +69,7 @@ end
7169
end
7270

7371
function MMI.fit(model::KMedoids, verbosity::Int, X)
74-
# NOTE: using transpose=true will materialize the transpose (~ permutedims), KMedoids
75-
# does not yet accept LinearAlgebra.Transpose
76-
Xarray = MMI.matrix(X, transpose=true)
72+
Xarray = MMI.matrix(X)'
7773
# cost matrix: all the pairwise distances
7874
cost_array = pairwise(model.metric, Xarray, dims=2) # n x n
7975
result = Cl.kmedoids(cost_array, model.k, init = model.init)
@@ -93,7 +89,7 @@ function MMI.transform(model::KMedoids, fitresult, X)
9389
# pairwise distance from samples to medoids
9490
= pairwise(
9591
model.metric,
96-
MMI.matrix(X, transpose=true),
92+
MMI.matrix(X)',
9793
fitresult[1], dims=2
9894
)
9995
return MMI.table(X̃, prototype=X)

test/Project.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@ Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
33
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
44
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
55
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
6-
MLJTestIntegration = "697918b4-fdc1-4f9e-8ff9-929724cee270"
6+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9-
10-
[compat]
11-
MLJTestIntegration = "0.2.2"

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Distances
33
import LinearAlgebra: norm
44

55
using MLJBase
6-
using MLJTestIntegration
6+
using MLJTestInterface
77
using MLJClusteringInterface
88
using Random: seed!
99
using Test
@@ -122,11 +122,11 @@ end
122122

123123
@testset "MLJ interface" begin
124124
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering]
125-
failures, summary = MLJTestIntegration.test(
125+
failures, summary = MLJTestInterface.test(
126126
models,
127127
X;
128128
mod=@__MODULE__,
129-
verbosity=0,
129+
verbosity=0, # bump to debug
130130
throw=false, # set to true to debug
131131
)
132132
@test isempty(failures)

0 commit comments

Comments
 (0)