Skip to content

Commit 3008afa

Browse files
committed
Pick centers in KMeans++ with a probability proportional to their distance^2, instead of simple distance, to previous centers
1 parent 294d908 commit 3008afa

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/cpp/flann/algorithms/center_chooser.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,11 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
176176
assert(index >=0 && index < n);
177177
centers[0] = indices[index];
178178

179+
// Computing distance^2 will have the advantage of even higher probability further to pick new centers
180+
// far from previous centers (and this complies to "k-means++: the advantages of careful seeding" article)
179181
for (int i = 0; i < n; i++) {
180182
closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
183+
closestDistSq[i] *= closestDistSq[i];
181184
currentPot += closestDistSq[i];
182185
}
183186

@@ -203,7 +206,10 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
203206

204207
// Compute the new potential
205208
double newPot = 0;
206-
for (int i = 0; i < n; i++) newPot += std::min( distance_(points_[indices[i]], points_[indices[index]], cols_), closestDistSq[i] );
209+
for (int i = 0; i < n; i++) {
210+
DistanceType dist = distance_(points_[indices[i]], points_[indices[index]], cols_);
211+
newPot += std::min( dist*dist, closestDistSq[i] );
212+
}
207213

208214
// Store the best result
209215
if ((bestNewPot < 0)||(newPot < bestNewPot)) {
@@ -215,7 +221,10 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
215221
// Add the appropriate center
216222
centers[centerCount] = indices[bestNewIndex];
217223
currentPot = bestNewPot;
218-
for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_), closestDistSq[i] );
224+
for (int i = 0; i < n; i++) {
225+
DistanceType dist = distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_);
226+
closestDistSq[i] = std::min( dist*dist, closestDistSq[i] );
227+
}
219228
}
220229

221230
centers_length = centerCount;

0 commit comments

Comments
 (0)