Skip to content

Commit dbf1801

Browse files
committed
Merge pull request flann-lib#157 from pemmanuelviel/groupWiseCenterChooser
Add a new method for initializing KMeans centers that leads to better clusters and thus better retrieval when final centers have to be existing keypoints instead of clusters barycenters.
2 parents fa619c8 + c06f813 commit dbf1801

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

src/cpp/flann/algorithms/center_chooser.h

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,94 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
291291
};
292292

293293

294+
295+
/**
296+
* Chooses the initial centers in a way inspired by Gonzales (by Pierre-Emmanuel Viel):
297+
* select the first point of the list as a candidate, then parse the points list. If another
298+
* point is further than current candidate from the other centers, test if it is a good center
299+
* of a local aggregation. If it is, replace current candidate by this point. And so on...
300+
*
301+
* Used with KMeansIndex that computes centers coordinates by averaging positions of clusters points,
302+
* this doesn't make a real difference with previous methods. But used with HierarchicalClusteringIndex
303+
* class that pick centers among existing points instead of computing the barycenters, there is a real
304+
* improvement.
305+
*/
306+
template <typename Distance>
307+
class GroupWiseCenterChooser : public CenterChooser<Distance>
308+
{
309+
public:
310+
typedef typename Distance::ElementType ElementType;
311+
typedef typename Distance::ResultType DistanceType;
312+
313+
using CenterChooser<Distance>::points_;
314+
using CenterChooser<Distance>::distance_;
315+
using CenterChooser<Distance>::cols_;
316+
317+
GroupWiseCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
318+
CenterChooser<Distance>(distance, points) {}
319+
320+
void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
321+
{
322+
const float kSpeedUpFactor = 1.3f;
323+
324+
int n = indices_length;
325+
326+
DistanceType* closestDistSq = new DistanceType[n];
327+
328+
// Choose one random center and set the closestDistSq values
329+
int index = rand_int(n);
330+
assert(index >=0 && index < n);
331+
centers[0] = indices[index];
332+
333+
for (int i = 0; i < n; i++) {
334+
closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
335+
}
336+
337+
338+
// Choose each center
339+
int centerCount;
340+
for (centerCount = 1; centerCount < k; centerCount++) {
341+
342+
// Repeat several trials
343+
double bestNewPot = -1;
344+
int bestNewIndex = 0;
345+
DistanceType furthest = 0;
346+
for (index = 0; index < n; index++) {
347+
348+
// We will test only the potential of the points further than current candidate
349+
if( closestDistSq[index] > kSpeedUpFactor * (float)furthest ) {
350+
351+
// Compute the new potential
352+
double newPot = 0;
353+
for (int i = 0; i < n; i++) {
354+
newPot += std::min( distance_(points_[indices[i]], points_[indices[index]], cols_)
355+
, closestDistSq[i] );
356+
}
357+
358+
// Store the best result
359+
if ((bestNewPot < 0)||(newPot <= bestNewPot)) {
360+
bestNewPot = newPot;
361+
bestNewIndex = index;
362+
furthest = closestDistSq[index];
363+
}
364+
}
365+
}
366+
367+
// Add the appropriate center
368+
centers[centerCount] = indices[bestNewIndex];
369+
for (int i = 0; i < n; i++) {
370+
closestDistSq[i] = std::min( distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_)
371+
, closestDistSq[i] );
372+
}
373+
}
374+
375+
centers_length = centerCount;
376+
377+
delete[] closestDistSq;
378+
}
379+
};
380+
381+
294382
}
295383

296384

src/cpp/flann/algorithms/hierarchical_clustering_index.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
165165
case FLANN_CENTERS_KMEANSPP:
166166
chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
167167
break;
168+
case FLANN_CENTERS_GROUPWISE:
169+
chooseCenters_ = new GroupWiseCenterChooser<Distance>(distance_, points_);
170+
break;
168171
default:
169172
throw FLANNException("Unknown algorithm for choosing initial centers.");
170173
}

src/cpp/flann/defines.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ enum flann_centers_init_t
9797
FLANN_CENTERS_RANDOM = 0,
9898
FLANN_CENTERS_GONZALES = 1,
9999
FLANN_CENTERS_KMEANSPP = 2,
100+
FLANN_CENTERS_GROUPWISE = 3,
100101
};
101102

102103
enum flann_log_level_t

0 commit comments

Comments
 (0)