Skip to content

Commit 0f3789c

Browse files
committed
Fixing bug related to building some indexes when the dataset is not passed in the constructor
1 parent e10b2ed commit 0f3789c

10 files changed

+112
-30
lines changed

src/cpp/flann/algorithms/center_chooser.h

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,11 @@ class CenterChooser
2020
typedef typename Distance::ElementType ElementType;
2121
typedef typename Distance::ResultType DistanceType;
2222

23-
CenterChooser(const Distance& distance) : distance_(distance) {};
24-
25-
void setDataset(const flann::Matrix<ElementType>& dataset)
26-
{
27-
dataset_ = dataset;
28-
}
23+
CenterChooser(const Distance& distance, const std::vector<ElementType*>& points) : distance_(distance), points_(points) {};
2924

3025
virtual ~CenterChooser() {};
26+
27+
void setDataSize(size_t cols) { cols_ = cols; }
3128

3229
/**
3330
* Chooses cluster centers
@@ -41,8 +38,9 @@ class CenterChooser
4138
virtual void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length) = 0;
4239

4340
protected:
44-
flann::Matrix<ElementType> dataset_;
45-
Distance distance_;
41+
const Distance distance_;
42+
const std::vector<ElementType*>& points_;
43+
size_t cols_;
4644
};
4745

4846

@@ -52,11 +50,12 @@ class RandomCenterChooser : public CenterChooser<Distance>
5250
public:
5351
typedef typename Distance::ElementType ElementType;
5452
typedef typename Distance::ResultType DistanceType;
55-
using CenterChooser<Distance>::dataset_;
53+
using CenterChooser<Distance>::points_;
5654
using CenterChooser<Distance>::distance_;
55+
using CenterChooser<Distance>::cols_;
5756

58-
RandomCenterChooser(const Distance& distance) :
59-
CenterChooser<Distance>(distance) {}
57+
RandomCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
58+
CenterChooser<Distance>(distance, points) {}
6059

6160
void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
6261
{
@@ -77,7 +76,7 @@ class RandomCenterChooser : public CenterChooser<Distance>
7776
centers[index] = indices[rnd];
7877

7978
for (int j=0; j<index; ++j) {
80-
DistanceType sq = distance_(dataset_[centers[index]], dataset_[centers[j]], dataset_.cols);
79+
DistanceType sq = distance_(points_[centers[index]], points_[centers[j]], cols_);
8180
if (sq<1e-16) {
8281
duplicate = true;
8382
}
@@ -101,10 +100,12 @@ class GonzalesCenterChooser : public CenterChooser<Distance>
101100
typedef typename Distance::ElementType ElementType;
102101
typedef typename Distance::ResultType DistanceType;
103102

104-
using CenterChooser<Distance>::dataset_;
103+
using CenterChooser<Distance>::points_;
105104
using CenterChooser<Distance>::distance_;
105+
using CenterChooser<Distance>::cols_;
106106

107-
GonzalesCenterChooser(const Distance& distance) : CenterChooser<Distance>( distance) {}
107+
GonzalesCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
108+
CenterChooser<Distance>(distance, points) {}
108109

109110
void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
110111
{
@@ -121,9 +122,9 @@ class GonzalesCenterChooser : public CenterChooser<Distance>
121122
int best_index = -1;
122123
DistanceType best_val = 0;
123124
for (int j=0; j<n; ++j) {
124-
DistanceType dist = distance_(dataset_[centers[0]],dataset_[indices[j]],dataset_.cols);
125+
DistanceType dist = distance_(points_[centers[0]],points_[indices[j]],cols_);
125126
for (int i=1; i<index; ++i) {
126-
DistanceType tmp_dist = distance_(dataset_[centers[i]],dataset_[indices[j]],dataset_.cols);
127+
DistanceType tmp_dist = distance_(points_[centers[i]],points_[indices[j]],cols_);
127128
if (tmp_dist<dist) {
128129
dist = tmp_dist;
129130
}
@@ -156,10 +157,12 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
156157
typedef typename Distance::ElementType ElementType;
157158
typedef typename Distance::ResultType DistanceType;
158159

159-
using CenterChooser<Distance>::dataset_;
160+
using CenterChooser<Distance>::points_;
160161
using CenterChooser<Distance>::distance_;
162+
using CenterChooser<Distance>::cols_;
161163

162-
KMeansppCenterChooser(const Distance& distance) : CenterChooser<Distance>(distance) {}
164+
KMeansppCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
165+
CenterChooser<Distance>(distance, points) {}
163166

164167
void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
165168
{
@@ -174,7 +177,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
174177
centers[0] = indices[index];
175178

176179
for (int i = 0; i < n; i++) {
177-
closestDistSq[i] = distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols);
180+
closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
178181
currentPot += closestDistSq[i];
179182
}
180183

@@ -200,7 +203,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
200203

201204
// Compute the new potential
202205
double newPot = 0;
203-
for (int i = 0; i < n; i++) newPot += std::min( distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols), closestDistSq[i] );
206+
for (int i = 0; i < n; i++) newPot += std::min( distance_(points_[indices[i]], points_[indices[index]], cols_), closestDistSq[i] );
204207

205208
// Store the best result
206209
if ((bestNewPot < 0)||(newPot < bestNewPot)) {
@@ -212,7 +215,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
212215
// Add the appropriate center
213216
centers[centerCount] = indices[bestNewIndex];
214217
currentPot = bestNewPot;
215-
for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance_(dataset_[indices[i]], dataset_[indices[bestNewIndex]], dataset_.cols), closestDistSq[i] );
218+
for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_), closestDistSq[i] );
216219
}
217220

218221
centers_length = centerCount;

src/cpp/flann/algorithms/hierarchical_clustering_index.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
126126
leaf_max_size_ = get_param(index_params_,"leaf_max_size",100);
127127

128128
initCenterChooser();
129-
chooseCenters_->setDataset(inputData);
130129

131130
setDataset(inputData);
132131
}
@@ -158,13 +157,13 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
158157
{
159158
switch(centers_init_) {
160159
case FLANN_CENTERS_RANDOM:
161-
chooseCenters_ = new RandomCenterChooser<Distance>(distance_);
160+
chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
162161
break;
163162
case FLANN_CENTERS_GONZALES:
164-
chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_);
163+
chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
165164
break;
166165
case FLANN_CENTERS_KMEANSPP:
167-
chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_);
166+
chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
168167
break;
169168
default:
170169
throw FLANNException("Unknown algorithm for choosing initial centers.");
@@ -296,6 +295,8 @@ class HierarchicalClusteringIndex : public NNIndex<Distance>
296295
*/
297296
void buildIndexImpl()
298297
{
298+
chooseCenters_->setDataSize(veclen_);
299+
299300
if (branching_<2) {
300301
throw FLANNException("Branching factor must be at least 2");
301302
}

src/cpp/flann/algorithms/kmeans_index.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ class KMeansIndex : public NNIndex<Distance>
117117
cb_index_ = get_param(params,"cb_index",0.4f);
118118

119119
initCenterChooser();
120-
chooseCenters_->setDataset(inputData);
121-
122120
setDataset(inputData);
123121
}
124122

@@ -168,13 +166,13 @@ class KMeansIndex : public NNIndex<Distance>
168166
{
169167
switch(centers_init_) {
170168
case FLANN_CENTERS_RANDOM:
171-
chooseCenters_ = new RandomCenterChooser<Distance>(distance_);
169+
chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
172170
break;
173171
case FLANN_CENTERS_GONZALES:
174-
chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_);
172+
chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
175173
break;
176174
case FLANN_CENTERS_KMEANSPP:
177-
chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_);
175+
chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
178176
break;
179177
default:
180178
throw FLANNException("Unknown algorithm for choosing initial centers.");
@@ -330,6 +328,8 @@ class KMeansIndex : public NNIndex<Distance>
330328
*/
331329
void buildIndexImpl()
332330
{
331+
chooseCenters_->setDataSize(veclen_);
332+
333333
if (branching_<2) {
334334
throw FLANNException("Branching factor must be at least 2");
335335
}

test/flann_hierarchical_test.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ TEST_F(HierarchicalIndex_Brief100K, TestSearch)
6363
query, indices, dists, k_nn_, flann::SearchParams(2000), 0.9, gt_indices, gt_dists);
6464
}
6565

66+
TEST_F(HierarchicalIndex_Brief100K, TestSearch2)
67+
{
68+
TestSearch2<Distance>(data, flann::HierarchicalClusteringIndexParams(),
69+
query, indices, dists, k_nn_, flann::SearchParams(2000), 0.9, gt_indices, gt_dists);
70+
}
71+
72+
6673
TEST_F(HierarchicalIndex_Brief100K, TestAddIncremental)
6774
{
6875
TestAddIncremental<Distance>(data, flann::HierarchicalClusteringIndexParams(),
@@ -101,6 +108,9 @@ TEST_F(HierarchicalIndex_Brief100K, TestCopy2)
101108
}
102109

103110

111+
112+
113+
104114
int main(int argc, char** argv)
105115
{
106116
testing::InitGoogleTest(&argc, argv);

test/flann_kdtree_single_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ TEST_F(KDTreeSingle, TestSearch)
1919
query, indices, dists, knn, flann::SearchParams(-1), 0.99, gt_indices);
2020
}
2121

22+
TEST_F(KDTreeSingle, TestSearch2)
23+
{
24+
TestSearch2<L2_Simple<float> >(data, flann::KDTreeSingleIndexParams(12, false),
25+
query, indices, dists, knn, flann::SearchParams(-1), 0.99, gt_indices);
26+
}
27+
28+
2229
TEST_F(KDTreeSingle, TestSearchPadded)
2330
{
2431
flann::Matrix<float> data_padded;

test/flann_kdtree_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ TEST_F(KDTree_SIFT10K, TestSearch)
2222
dists, knn, flann::SearchParams(256), 0.75, gt_indices);
2323
}
2424

25+
TEST_F(KDTree_SIFT10K, TestSearch2)
26+
{
27+
TestSearch2<flann::L2<float> >(data, flann::KDTreeIndexParams(4), query, indices,
28+
dists, knn, flann::SearchParams(256), 0.75, gt_indices);
29+
}
30+
31+
2532
TEST_F(KDTree_SIFT10K, TestAddIncremental)
2633
{
2734
TestAddIncremental<flann::L2<float> >(data, flann::KDTreeIndexParams(4), query, indices,

test/flann_kmeans_test.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ TEST_F(KMeans_SIFT10K, TestSearch)
2525
query, indices, dists, knn, flann::SearchParams(128), 0.75, gt_indices);
2626
}
2727

28+
TEST_F(KMeans_SIFT10K, TestSearch2)
29+
{
30+
TestSearch2<flann::L2<float> >(data, flann::KMeansIndexParams(7, 3, FLANN_CENTERS_RANDOM, 0.4),
31+
query, indices, dists, knn, flann::SearchParams(128), 0.75, gt_indices);
32+
}
33+
2834

2935
TEST_F(KMeans_SIFT10K, TestAddIncremental)
3036
{

test/flann_linear_test.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ TEST_F(Linear_SIFT10K, TestSearch)
2323
query, indices, dists, knn, flann::SearchParams(0), 1.0, gt_indices);
2424
}
2525

26+
TEST_F(Linear_SIFT10K, TestSearch2)
27+
{
28+
TestSearch2<flann::L2<float> >(data, flann::LinearIndexParams(),
29+
query, indices, dists, knn, flann::SearchParams(0), 1.0, gt_indices);
30+
}
31+
2632
TEST_F(Linear_SIFT10K, TestRemove)
2733
{
2834
TestRemove<flann::L2<float> >(data, flann::LinearIndexParams(),

test/flann_lsh_test.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ TEST_F(LshIndex_Brief100K, TestSearch)
6767
query, indices, dists, k_nn_, flann::SearchParams(-1), 0.9, gt_indices, gt_dists);
6868
}
6969

70+
TEST_F(LshIndex_Brief100K, TestSearch2)
71+
{
72+
TestSearch2<Distance>(data, flann::LshIndexParams(12, 20, 2),
73+
query, indices, dists, k_nn_, flann::SearchParams(-1), 0.9, gt_indices, gt_dists);
74+
}
7075

7176
TEST_F(LshIndex_Brief100K, TestAddIncremental)
7277
{

test/flann_tests.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,43 @@ class FLANNTestFixture : public ::testing::Test {
164164
printf("Precision: %g\n", precision);
165165
}
166166

167+
168+
template<typename Distance>
169+
void TestSearch2(const flann::Matrix<typename Distance::ElementType>& data,
170+
const flann::IndexParams& index_params,
171+
const flann::Matrix<typename Distance::ElementType>& query,
172+
flann::Matrix<size_t>& indices,
173+
flann::Matrix<typename Distance::ResultType>& dists,
174+
size_t knn,
175+
const flann::SearchParams& search_params,
176+
float expected_precision,
177+
const flann::Matrix<size_t>& gt_indices,
178+
const flann::Matrix<typename Distance::ResultType>& gt_dists = flann::Matrix<typename Distance::ResultType>())
179+
{
180+
flann::seed_random(0);
181+
Index<Distance> index(index_params);
182+
char message[256];
183+
const char* index_name = index_type_to_name(index.getType());
184+
sprintf(message, "Building %s index... ", index_name);
185+
start_timer( message );
186+
index.buildIndex(data);
187+
printf("done (%g seconds)\n", stop_timer());
188+
189+
start_timer("Searching KNN...");
190+
index.knnSearch(query, indices, dists, knn, search_params );
191+
printf("done (%g seconds)\n", stop_timer());
192+
193+
float precision;
194+
if (gt_dists.ptr()==NULL) {
195+
precision = compute_precision(gt_indices, indices);
196+
}
197+
else {
198+
precision = computePrecisionDiscrete(gt_dists, dists);
199+
}
200+
EXPECT_GE(precision, expected_precision);
201+
printf("Precision: %g\n", precision);
202+
}
203+
167204
template<typename Distance>
168205
void TestAddIncremental(const flann::Matrix<typename Distance::ElementType>& data,
169206
const flann::IndexParams& index_params,

0 commit comments

Comments
 (0)