Skip to content

Commit 0fa018f

Browse files
committed
Fixed an issue with removing points from the index and then rebuilding the index. Added some performance improvements for the case where no points have been removed.
1 parent 7f157ed commit 0fa018f

13 files changed

+436
-251
lines changed

src/cpp/flann/algorithms/hierarchical_clustering_index.h

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
203203
*/
204204
void buildIndex()
205205
{
206+
freeIndex();
207+
cleanRemovedPoints();
208+
206209
if (branching_<2) {
207210
throw FLANNException("Branching factor must be at least 2");
208211
}
@@ -227,7 +230,6 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
227230
extendDataset(points);
228231

229232
if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
230-
freeIndex();
231233
buildIndex();
232234
}
233235
else {
@@ -300,31 +302,17 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
300302
* vec = the vector for which to search the nearest neighbors
301303
* searchParams = parameters that influence the search algorithm (checks)
302304
*/
305+
303306
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
304307
{
305-
306-
int maxChecks = searchParams.checks;
307-
308-
// Priority queue storing intermediate branches in the best-bin-first search
309-
Heap<BranchSt>* heap = new Heap<BranchSt>(size_);
310-
311-
DynamicBitset checked(size_);
312-
int checks = 0;
313-
for (int i=0; i<trees_; ++i) {
314-
findNN(tree_roots_[i], result, vec, checks, maxChecks, heap, checked);
315-
}
316-
317-
BranchSt branch;
318-
while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
319-
NodePtr node = branch.node;
320-
findNN(node, result, vec, checks, maxChecks, heap, checked);
321-
}
322-
323-
delete heap;
324-
308+
if (removed_) {
309+
findNeighbors<true>(result, vec, searchParams);
310+
}
311+
else {
312+
findNeighbors<false>(result, vec, searchParams);
313+
}
325314
}
326315

327-
328316
private:
329317

330318
struct PointInfo
@@ -538,6 +526,29 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
538526
}
539527

540528

529+
template<bool with_removed>
530+
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
531+
{
532+
int maxChecks = searchParams.checks;
533+
534+
// Priority queue storing intermediate branches in the best-bin-first search
535+
Heap<BranchSt>* heap = new Heap<BranchSt>(size_);
536+
537+
DynamicBitset checked(size_);
538+
int checks = 0;
539+
for (int i=0; i<trees_; ++i) {
540+
findNN<with_removed>(tree_roots_[i], result, vec, checks, maxChecks, heap, checked);
541+
}
542+
543+
BranchSt branch;
544+
while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
545+
NodePtr node = branch.node;
546+
findNN<with_removed>(node, result, vec, checks, maxChecks, heap, checked);
547+
}
548+
549+
delete heap;
550+
}
551+
541552

542553
/**
543554
* Performs one descent in the hierarchical k-means tree. The branches not
@@ -551,9 +562,8 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
551562
* maxChecks = maximum dataset points to check
552563
*/
553564

554-
555-
template<typename ResultSet>
556-
void findNN(NodePtr node, ResultSet& result, const ElementType* vec, int& checks, int maxChecks,
565+
template<bool with_removed>
566+
void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
557567
Heap<BranchSt>* heap, DynamicBitset& checked)
558568
{
559569
if (node->childs.empty()) {
@@ -563,8 +573,10 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
563573

564574
for (size_t i=0; i<node->points.size(); ++i) {
565575
PointInfo& pointInfo = node->points[i];
566-
567-
if (checked.test(pointInfo.index) || removed_points_.test(pointInfo.index)) continue;
576+
if (with_removed) {
577+
if (removed_points_.test(pointInfo.index)) continue;
578+
}
579+
if (checked.test(pointInfo.index)) continue;
568580
DistanceType dist = distance_(pointInfo.point, vec, veclen_);
569581
result.addPoint(dist, pointInfo.index);
570582
checked.set(pointInfo.index);
@@ -587,7 +599,7 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
587599
}
588600
}
589601
delete[] domain_distances;
590-
findNN(node->childs[best_index],result,vec, checks, maxChecks, heap, checked);
602+
findNN<with_removed>(node->childs[best_index],result,vec, checks, maxChecks, heap, checked);
591603
}
592604
}
593605

src/cpp/flann/algorithms/kdtree_index.h

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ class KDTreeIndex : public NNIndex<Distance>
143143
*/
144144
void buildIndex()
145145
{
146+
freeIndex();
147+
cleanRemovedPoints();
148+
146149
// Create a permutable array of indices to the input vectors.
147150
std::vector<int> ind(size_);
148151
for (size_t i = 0; i < size_; ++i) {
@@ -256,10 +259,20 @@ class KDTreeIndex : public NNIndex<Distance>
256259
float epsError = 1+searchParams.eps;
257260

258261
if (maxChecks==FLANN_CHECKS_UNLIMITED) {
259-
getExactNeighbors(result, vec, epsError);
262+
if (removed_) {
263+
getExactNeighbors<true>(result, vec, epsError);
264+
}
265+
else {
266+
getExactNeighbors<false>(result, vec, epsError);
267+
}
260268
}
261269
else {
262-
getNeighbors(result, vec, maxChecks, epsError);
270+
if (removed_) {
271+
getNeighbors<true>(result, vec, maxChecks, epsError);
272+
}
273+
else {
274+
getNeighbors<false>(result, vec, maxChecks, epsError);
275+
}
263276
}
264277
}
265278

@@ -510,16 +523,16 @@ class KDTreeIndex : public NNIndex<Distance>
510523
* Performs an exact nearest neighbor search. The exact search performs a full
511524
* traversal of the tree.
512525
*/
513-
template<typename ResultSet>
514-
void getExactNeighbors(ResultSet& result, const ElementType* vec, float epsError)
526+
template<bool with_removed>
527+
void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
515528
{
516529
// checkID -= 1; /* Set a different unique ID for each search. */
517530

518531
if (trees_ > 1) {
519532
fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
520533
}
521534
if (trees_>0) {
522-
searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
535+
searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError);
523536
}
524537
}
525538

@@ -528,6 +541,7 @@ class KDTreeIndex : public NNIndex<Distance>
528541
* because the tree traversal is abandoned after a given number of descents in
529542
* the tree.
530543
*/
544+
template<bool with_removed>
531545
void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError)
532546
{
533547
int i;
@@ -539,12 +553,12 @@ class KDTreeIndex : public NNIndex<Distance>
539553

540554
/* Search once through each tree, from its root down to a leaf. */
541555
for (i = 0; i < trees_; ++i) {
542-
searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
556+
searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
543557
}
544558

545559
/* Keep searching other branches from heap until finished. */
546560
while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
547-
searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
561+
searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
548562
}
549563

550564
delete heap;
@@ -556,6 +570,7 @@ class KDTreeIndex : public NNIndex<Distance>
556570
* higher levels, all exemplars below this level must have a distance of
557571
* at least "mindist".
558572
*/
573+
template<bool with_removed>
559574
void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
560575
float epsError, Heap<BranchSt>* heap, DynamicBitset& checked)
561576
{
@@ -567,8 +582,11 @@ class KDTreeIndex : public NNIndex<Distance>
567582
/* If this is a leaf node, then do check and return. */
568583
if ((node->child1 == NULL)&&(node->child2 == NULL)) {
569584
int index = node->divfeat;
585+
if (with_removed) {
586+
if (removed_points_.test(index)) return;
587+
}
570588
/* Do not check same node more than once when searching multiple trees. */
571-
if ( checked.test(index) || removed_points_.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
589+
if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
572590
checked.set(index);
573591
checkCount++;
574592

@@ -598,19 +616,21 @@ class KDTreeIndex : public NNIndex<Distance>
598616
}
599617

600618
/* Call recursively to search next level down. */
601-
searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
619+
searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
602620
}
603621

604622
/**
605623
* Performs an exact search in the tree starting from a node.
606624
*/
607-
template<typename ResultSet>
608-
void searchLevelExact(ResultSet& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
625+
template<bool with_removed>
626+
void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
609627
{
610628
/* If this is a leaf node, then do check and return. */
611629
if ((node->child1 == NULL)&&(node->child2 == NULL)) {
612630
int index = node->divfeat;
613-
if (removed_points_.test(index)) return; // ignore removed points
631+
if (with_removed) {
632+
if (removed_points_.test(index)) return; // ignore removed points
633+
}
614634
DistanceType dist = distance_(node->point, vec, veclen_);
615635
result_set.addPoint(dist,index);
616636

@@ -634,10 +654,10 @@ class KDTreeIndex : public NNIndex<Distance>
634654
DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
635655

636656
/* Call recursively to search next level down. */
637-
searchLevelExact(result_set, vec, bestChild, mindist, epsError);
657+
searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError);
638658

639659
if (mindist*epsError<=result_set.worstDist()) {
640-
searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
660+
searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError);
641661
}
642662
}
643663

0 commit comments

Comments
 (0)