Skip to content

Commit 9125041

Browse files
committed
Merge pull request flann-lib#152 from pemmanuelviel/master
Better results with centers far from each other when selecting them in the chooseCentersKMeanspp stage
2 parents 294d908 + 0776429 commit 9125041

File tree

1 file changed

+68
-2
lines changed

1 file changed

+68
-2
lines changed

src/cpp/flann/algorithms/center_chooser.h

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,63 @@
1313
namespace flann
1414
{
1515

16+
template <typename Distance, typename ElementType>
17+
struct squareDistance
18+
{
19+
typedef typename Distance::ResultType ResultType;
20+
ResultType operator()( ResultType dist ) { return dist*dist; }
21+
};
22+
23+
24+
template <typename ElementType>
25+
struct squareDistance<L2_Simple<ElementType>, ElementType>
26+
{
27+
typedef typename L2_Simple<ElementType>::ResultType ResultType;
28+
ResultType operator()( ResultType dist ) { return dist; }
29+
};
30+
31+
template <typename ElementType>
32+
struct squareDistance<L2_3D<ElementType>, ElementType>
33+
{
34+
typedef typename L2_3D<ElementType>::ResultType ResultType;
35+
ResultType operator()( ResultType dist ) { return dist; }
36+
};
37+
38+
template <typename ElementType>
39+
struct squareDistance<L2<ElementType>, ElementType>
40+
{
41+
typedef typename L2<ElementType>::ResultType ResultType;
42+
ResultType operator()( ResultType dist ) { return dist; }
43+
};
44+
45+
46+
template <typename ElementType>
47+
struct squareDistance<HellingerDistance<ElementType>, ElementType>
48+
{
49+
typedef typename HellingerDistance<ElementType>::ResultType ResultType;
50+
ResultType operator()( ResultType dist ) { return dist; }
51+
};
52+
53+
54+
template <typename ElementType>
55+
struct squareDistance<ChiSquareDistance<ElementType>, ElementType>
56+
{
57+
typedef typename ChiSquareDistance<ElementType>::ResultType ResultType;
58+
ResultType operator()( ResultType dist ) { return dist; }
59+
};
60+
61+
62+
template <typename Distance>
63+
typename Distance::ResultType ensureSquareDistance( typename Distance::ResultType dist )
64+
{
65+
typedef typename Distance::ElementType ElementType;
66+
67+
squareDistance<Distance, ElementType> dummy;
68+
return dummy( dist );
69+
}
70+
71+
72+
1673
template <typename Distance>
1774
class CenterChooser
1875
{
@@ -176,8 +233,11 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
176233
assert(index >=0 && index < n);
177234
centers[0] = indices[index];
178235

236+
// Computing distance^2 will have the advantage of even higher probability further to pick new centers
237+
// far from previous centers (and this complies to "k-means++: the advantages of careful seeding" article)
179238
for (int i = 0; i < n; i++) {
180239
closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
240+
closestDistSq[i] = ensureSquareDistance<Distance>( closestDistSq[i] );
181241
currentPot += closestDistSq[i];
182242
}
183243

@@ -203,7 +263,10 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
203263

204264
// Compute the new potential
205265
double newPot = 0;
206-
for (int i = 0; i < n; i++) newPot += std::min( distance_(points_[indices[i]], points_[indices[index]], cols_), closestDistSq[i] );
266+
for (int i = 0; i < n; i++) {
267+
DistanceType dist = distance_(points_[indices[i]], points_[indices[index]], cols_);
268+
newPot += std::min( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
269+
}
207270

208271
// Store the best result
209272
if ((bestNewPot < 0)||(newPot < bestNewPot)) {
@@ -215,7 +278,10 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
215278
// Add the appropriate center
216279
centers[centerCount] = indices[bestNewIndex];
217280
currentPot = bestNewPot;
218-
for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_), closestDistSq[i] );
281+
for (int i = 0; i < n; i++) {
282+
DistanceType dist = distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_);
283+
closestDistSq[i] = std::min( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
284+
}
219285
}
220286

221287
centers_length = centerCount;

0 commit comments

Comments
 (0)