1313namespace flann
1414{
1515
/**
 * Functor that turns a distance value into a squared distance.
 *
 * This is the generic case: the value passed in is multiplied by itself.
 * Partial specializations of this template for particular Distance types
 * may return the value unchanged instead.
 *
 * Distance must provide a nested ResultType. ElementType is not used by the
 * generic case; it only takes part in selecting a partial specialization.
 */
template <typename Distance, typename ElementType>
struct squareDistance
{
    typedef typename Distance::ResultType ResultType;

    /** Returns dist * dist. Const-qualified: the functor holds no state. */
    ResultType operator()(ResultType dist) const { return dist * dist; }
};
22+
23+
24+ template <typename ElementType>
25+ struct squareDistance <L2_Simple<ElementType>, ElementType>
26+ {
27+ typedef typename L2_Simple<ElementType>::ResultType ResultType;
28+ ResultType operator ()( ResultType dist ) { return dist; }
29+ };
30+
31+ template <typename ElementType>
32+ struct squareDistance <L2_3D<ElementType>, ElementType>
33+ {
34+ typedef typename L2_3D<ElementType>::ResultType ResultType;
35+ ResultType operator ()( ResultType dist ) { return dist; }
36+ };
37+
38+ template <typename ElementType>
39+ struct squareDistance <L2<ElementType>, ElementType>
40+ {
41+ typedef typename L2<ElementType>::ResultType ResultType;
42+ ResultType operator ()( ResultType dist ) { return dist; }
43+ };
44+
45+
46+ template <typename ElementType>
47+ struct squareDistance <HellingerDistance<ElementType>, ElementType>
48+ {
49+ typedef typename HellingerDistance<ElementType>::ResultType ResultType;
50+ ResultType operator ()( ResultType dist ) { return dist; }
51+ };
52+
53+
54+ template <typename ElementType>
55+ struct squareDistance <ChiSquareDistance<ElementType>, ElementType>
56+ {
57+ typedef typename ChiSquareDistance<ElementType>::ResultType ResultType;
58+ ResultType operator ()( ResultType dist ) { return dist; }
59+ };
60+
61+
62+ template <typename Distance>
63+ typename Distance::ResultType ensureSquareDistance ( typename Distance::ResultType dist )
64+ {
65+ typedef typename Distance::ElementType ElementType;
66+
67+ squareDistance<Distance, ElementType> dummy;
68+ return dummy ( dist );
69+ }
70+
71+
72+
1673template <typename Distance>
1774class CenterChooser
1875{
@@ -176,8 +233,11 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
176233 assert (index >=0 && index < n);
177234 centers[0 ] = indices[index];
178235
236+ // Using squared distances (distance^2) further increases the probability of picking new centers
237+ // far from the previous centers (this follows the article "k-means++: The Advantages of Careful Seeding")
179238 for (int i = 0 ; i < n; i++) {
180239 closestDistSq[i] = distance_ (points_[indices[i]], points_[indices[index]], cols_);
240+ closestDistSq[i] = ensureSquareDistance<Distance>( closestDistSq[i] );
181241 currentPot += closestDistSq[i];
182242 }
183243
@@ -203,7 +263,10 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
203263
204264 // Compute the new potential
205265 double newPot = 0 ;
206- for (int i = 0 ; i < n; i++) newPot += std::min ( distance_ (points_[indices[i]], points_[indices[index]], cols_), closestDistSq[i] );
266+ for (int i = 0 ; i < n; i++) {
267+ DistanceType dist = distance_ (points_[indices[i]], points_[indices[index]], cols_);
268+ newPot += std::min ( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
269+ }
207270
208271 // Store the best result
209272 if ((bestNewPot < 0 )||(newPot < bestNewPot)) {
@@ -215,7 +278,10 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
215278 // Add the appropriate center
216279 centers[centerCount] = indices[bestNewIndex];
217280 currentPot = bestNewPot;
218- for (int i = 0 ; i < n; i++) closestDistSq[i] = std::min ( distance_ (points_[indices[i]], points_[indices[bestNewIndex]], cols_), closestDistSq[i] );
281+ for (int i = 0 ; i < n; i++) {
282+ DistanceType dist = distance_ (points_[indices[i]], points_[indices[bestNewIndex]], cols_);
283+ closestDistSq[i] = std::min ( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
284+ }
219285 }
220286
221287 centers_length = centerCount;
0 commit comments