Skip to content

Commit db5b20c

Browse files
committed
Linkage funcs extracted and cached.
This only works for input data with hashable values for now. This gives a tremendous speed boost at the cost of memory. The memory cost can however be improved by removing "clusters inside other clusters".
1 parent d6ebdb2 commit db5b20c

File tree

2 files changed

+112
-120
lines changed

2 files changed

+112
-120
lines changed

cluster/linkage.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from __future__ import division
2+
from functools import wraps
3+
4+
5+
def cached(fun):
6+
"""
7+
memoizing decorator for linkage functions.
8+
9+
Parameters have been hardcoded (no ``*args``, ``**kwargs`` magic), because,
10+
the way this is coded (interchangingly using sets and frozensets) is true
11+
for this specific case. For other cases that is not necessarily guaranteed.
12+
"""
13+
return fun
14+
15+
_cache = {}
16+
17+
@wraps(fun)
18+
def newfun(a, b, distance_function):
19+
frozen_a = frozenset(a)
20+
frozen_b = frozenset(b)
21+
if (frozen_a, frozen_b) not in _cache:
22+
result = fun(a, b, distance_function)
23+
_cache[(frozen_a, frozen_b)] = result
24+
return _cache[(frozen_a, frozen_b)]
25+
return newfun
26+
27+
28+
@cached
29+
def single(a, b, distance_function):
30+
"""
31+
Given two collections ``a`` and ``b``, this will return the distance of the
32+
points which are closest togetger. ``distance_function`` is used to
33+
determine the distance between two elements.
34+
35+
Example::
36+
37+
>>> single([1, 2], [3, 4], lambda x, y: abs(x-y))
38+
1 # (distance between 2 and 3)
39+
"""
40+
left_a, right_a = min(a), max(a)
41+
left_b, right_b = min(b), max(b)
42+
result = min(distance_function(left_a, right_b),
43+
distance_function(left_b, right_a))
44+
return result
45+
46+
47+
@cached
48+
def complete(a, b, distance_function):
49+
"""
50+
Given two collections ``a`` and ``b``, this will return the distance of the
51+
points which are farthest apart. ``distance_function`` is used to determine
52+
the distance between two elements.
53+
54+
Example::
55+
56+
>>> single([1, 2], [3, 4], lambda x, y: abs(x-y))
57+
3 # (distance between 1 and 4)
58+
"""
59+
left_a, right_a = min(a), max(a)
60+
left_b, right_b = min(b), max(b)
61+
result = max(distance_function(left_a, right_b),
62+
distance_function(left_b, right_a))
63+
return result
64+
65+
66+
@cached
67+
def average(a, b, distance_function):
68+
"""
69+
Given two collections ``a`` and ``b``, this will return the mean of all
70+
distances. ``distance_function`` is used to determine the distance between
71+
two elements.
72+
73+
Example::
74+
75+
>>> single([1, 2], [3, 100], lambda x, y: abs(x-y))
76+
26
77+
"""
78+
distances = [distance_function(x, y)
79+
for x in a for y in b]
80+
return sum(distances) / len(distances)
81+
82+
83+
@cached
84+
def uclus(a, b, distance_function):
85+
"""
86+
Given two collections ``a`` and ``b``, this will return the *median* of all
87+
distances. ``distance_function`` is used to determine the distance between
88+
two elements.
89+
90+
Example::
91+
92+
>>> single([1, 2], [3, 100], lambda x, y: abs(x-y))
93+
2.5
94+
"""
95+
distances = sorted([distance_function(x, y)
96+
for x in a for y in b])
97+
midpoint, rest = len(distances) // 2, len(distances) % 2
98+
if not rest:
99+
return sum(distances[midpoint-1:midpoint+1]) / 2
100+
else:
101+
return distances[midpoint]

cluster/method/hierarchical.py

Lines changed: 11 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
1616
#
1717

18+
from functools import partial
1819
import logging
1920

2021
from cluster.cluster import Cluster
2122
from cluster.matrix import Matrix
2223
from cluster.method.base import BaseClusterMethod
2324
from cluster.util import median, mean, fullyflatten
25+
from cluster.linkage import single, complete, average, uclus
2426

2527

2628
logger = logging.getLogger(__name__)
@@ -58,7 +60,7 @@ class HierarchicalClustering(BaseClusterMethod):
5860

5961
def __init__(self, data, distance_function, linkage=None, num_processes=1):
6062
if not linkage:
61-
linkage = 'single'
63+
linkage = single
6264
logger.info("Initializing HierarchicalClustering object with linkage "
6365
"method %s", linkage)
6466
BaseClusterMethod.__init__(self, sorted(data), distance_function)
@@ -74,131 +76,19 @@ def set_linkage_method(self, method):
7476
``'single'``, ``'complete'``, ``'average'`` or ``'uclus'``.
7577
"""
7678
if method == 'single':
77-
self.linkage = self.single_linkage_distance
79+
self.linkage = single
7880
elif method == 'complete':
79-
self.linkage = self.complete_linkage_distance
81+
self.linkage = complete
8082
elif method == 'average':
81-
self.linkage = self.average_linkage_distance
83+
self.linkage = average
8284
elif method == 'uclus':
83-
self.linkage = self.uclus_distance
85+
self.linkage = uclus
86+
elif hasattr(method, '__call__'):
87+
self.linkage = method
8488
else:
8589
raise ValueError('distance method must be one of single, '
8690
'complete, average of uclus')
8791

88-
def uclus_distance(self, x, y):
89-
"""
90-
The method to determine the distance between one cluster an another
91-
item/cluster. The distance equals to the *average* (median) distance
92-
from any member of one cluster to any member of the other cluster.
93-
94-
:param x: first cluster/item.
95-
:param y: second cluster/item.
96-
"""
97-
# create a flat list of all the items in <x>
98-
if not isinstance(x, Cluster):
99-
x = [x]
100-
else:
101-
x = fullyflatten(x.items)
102-
103-
# create a flat list of all the items in <y>
104-
if not isinstance(y, Cluster):
105-
y = [y]
106-
else:
107-
y = fullyflatten(y.items)
108-
109-
distances = []
110-
for k in x:
111-
for l in y:
112-
distances.append(self.distance(k, l))
113-
return median(distances)
114-
115-
def average_linkage_distance(self, x, y):
116-
"""
117-
The method to determine the distance between one cluster an another
118-
item/cluster. The distance equals to the *average* (mean) distance
119-
from any member of one cluster to any member of the other cluster.
120-
121-
:param x: first cluster/item.
122-
:param y: second cluster/item.
123-
"""
124-
# create a flat list of all the items in <x>
125-
if not isinstance(x, Cluster):
126-
x = [x]
127-
else:
128-
x = fullyflatten(x.items)
129-
130-
# create a flat list of all the items in <y>
131-
if not isinstance(y, Cluster):
132-
y = [y]
133-
else:
134-
y = fullyflatten(y.items)
135-
136-
distances = []
137-
for k in x:
138-
for l in y:
139-
distances.append(self.distance(k, l))
140-
return mean(distances)
141-
142-
def complete_linkage_distance(self, x, y):
143-
"""
144-
The method to determine the distance between one cluster an another
145-
item/cluster. The distance equals to the *longest* distance from any
146-
member of one cluster to any member of the other cluster.
147-
148-
:param x: first cluster/item.
149-
:param y: second cluster/item.
150-
"""
151-
152-
# create a flat list of all the items in <x>
153-
if not isinstance(x, Cluster):
154-
x = [x]
155-
else:
156-
x = fullyflatten(x.items)
157-
158-
# create a flat list of all the items in <y>
159-
if not isinstance(y, Cluster):
160-
y = [y]
161-
else:
162-
y = fullyflatten(y.items)
163-
164-
# retrieve the minimum distance (single-linkage)
165-
maxdist = self.distance(x[0], y[0])
166-
for k in x:
167-
for l in y:
168-
maxdist = max(maxdist, self.distance(k, l))
169-
170-
return maxdist
171-
172-
def single_linkage_distance(self, x, y):
173-
"""
174-
The method to determine the distance between one cluster an another
175-
item/cluster. The distance equals to the *shortest* distance from any
176-
member of one cluster to any member of the other cluster.
177-
178-
:param x: first cluster/item.
179-
:param y: second cluster/item.
180-
"""
181-
182-
# create a flat list of all the items in <x>
183-
if not isinstance(x, Cluster):
184-
x = [x]
185-
else:
186-
x = fullyflatten(x.items)
187-
188-
# create a flat list of all the items in <y>
189-
if not isinstance(y, Cluster):
190-
y = [y]
191-
else:
192-
y = fullyflatten(y.items)
193-
194-
# retrieve the minimum distance (single-linkage)
195-
mindist = self.distance(x[0], y[0])
196-
for k in x:
197-
for l in y:
198-
mindist = min(mindist, self.distance(k, l))
199-
200-
return mindist
201-
20292
def cluster(self, matrix=None, level=None, sequence=None):
20393
"""
20494
Perform hierarchical clustering.
@@ -217,10 +107,11 @@ def cluster(self, matrix=None, level=None, sequence=None):
217107
matrix = []
218108

219109
# if the matrix only has two rows left, we are done
110+
linkage = partial(self.linkage, distance_function=self.distance)
220111
while len(matrix) > 2 or matrix == []:
221112

222113
item_item_matrix = Matrix(self._data,
223-
self.linkage,
114+
linkage,
224115
True,
225116
0)
226117
item_item_matrix.genmatrix(self.num_processes)

0 commit comments

Comments
 (0)