Skip to content

Commit 9b8358d

Browse files
committed
Fix for issue #10
KMeansClustering now accepts an optional equality function. This is useful when using numpy arrays as inputs.
1 parent c628d8b commit 9b8358d

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

cluster.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ class KMeansClustering:
680680
>>> clusters = cl.getclusters(2)
681681
"""
682682

683-
def __init__(self, data, distance=None):
683+
def __init__(self, data, distance=None, equality=None):
684684
"""
685685
Constructor
686686
@@ -690,11 +690,14 @@ def __init__(self, data, distance=None):
690690
Default: It assumes the tuples contain numeric values
691691
and appiles a generalised form of the
692692
euclidian-distance algorithm on them.
693+
equality - A function to test equality of items. By default the
694+
standard python equality operator (``==``) is applied.
693695
"""
694696
self.__clusters = []
695697
self.__data = data
696698
self.distance = distance
697699
self.__initial_length = len(data)
700+
self.equality = equality
698701

699702
# test if each item is of same dimensions
700703
if len(data) > 1 and isinstance(data[0], TupleType):
@@ -768,7 +771,7 @@ def assign_item(self, item, origin):
768771
centroid(closest_cluster)):
769772
closest_cluster = cluster
770773

771-
if closest_cluster != origin:
774+
if id(closest_cluster) != id(origin):
772775
self.move_item(item, origin, closest_cluster)
773776
return True
774777
else:
@@ -784,7 +787,16 @@ def move_item(self, item, origin, destination):
784787
origin - the originating cluster
785788
destination - the target cluster
786789
"""
787-
destination.append(origin.pop(origin.index(item)))
790+
if self.equality:
791+
item_index = 0
792+
for i, element in enumerate(origin):
793+
if self.equality(element, item):
794+
item_index = i
795+
break
796+
else:
797+
item_index = origin.index(item)
798+
799+
destination.append(origin.pop(item_index))
788800

789801
def initialise_clusters(self, input_, clustercount):
790802
"""

test.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,41 @@ def testLostFunctionReference(self):
202202
expected),
203203
"Elements differ!\n%s\n%s" % (clusters, expected))
204204

205+
def testMultidimArray(self):
206+
from random import random
207+
data = []
208+
for _ in range(200):
209+
data.append([random(), random()])
210+
cl = KMeansClustering(data, lambda p0, p1: (
211+
p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2)
212+
cl.getclusters(10)
213+
214+
215+
class NumpyTests(unittest.TestCase):
216+
217+
def testNumpyRandom(self):
218+
from cluster import KMeansClustering
219+
from numpy import random as rnd
220+
data = rnd.rand(500, 2)
221+
cl = KMeansClustering(data, lambda p0, p1: (
222+
p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2, numpy.array_equal)
223+
cl.getclusters(10)
224+
205225

206226
if __name__ == '__main__':
207-
unittest.TextTestRunner(verbosity=2).run(
208-
unittest.TestSuite((
209-
unittest.makeSuite(HClusterSmallListTestCase),
210-
unittest.makeSuite(HClusterIntegerTestCase),
211-
unittest.makeSuite(HClusterStringTestCase),
212-
unittest.makeSuite(KClusterSmallListTestCase),
213-
unittest.makeSuite(KCluster2DTestCase),
214-
unittest.makeSuite(KClusterSFBugs),
215-
))
216-
)
227+
suite = unittest.TestSuite((
228+
unittest.makeSuite(HClusterSmallListTestCase),
229+
unittest.makeSuite(HClusterIntegerTestCase),
230+
unittest.makeSuite(HClusterStringTestCase),
231+
unittest.makeSuite(KClusterSmallListTestCase),
232+
unittest.makeSuite(KCluster2DTestCase),
233+
unittest.makeSuite(KClusterSFBugs)))
234+
235+
try:
236+
import numpy # NOQA
237+
tests = unittest.makeSuite(NumpyTests)
238+
suite.addTests(tests)
239+
except ImportError:
240+
print "numpy not available. Associated test will not be loaded!"
241+
242+
unittest.TextTestRunner(verbosity=2).run(suite)

0 commit comments

Comments
 (0)