13
13
namespace flann
14
14
{
15
15
16
+ template <typename Distance, typename ElementType>
17
+ struct squareDistance
18
+ {
19
+ typedef typename Distance::ResultType ResultType;
20
+ ResultType operator ()( ResultType dist ) { return dist*dist; }
21
+ };
22
+
23
+
24
+ template <typename ElementType>
25
+ struct squareDistance <L2_Simple<ElementType>, ElementType>
26
+ {
27
+ typedef typename L2_Simple<ElementType>::ResultType ResultType;
28
+ ResultType operator ()( ResultType dist ) { return dist; }
29
+ };
30
+
31
+ template <typename ElementType>
32
+ struct squareDistance <L2_3D<ElementType>, ElementType>
33
+ {
34
+ typedef typename L2_3D<ElementType>::ResultType ResultType;
35
+ ResultType operator ()( ResultType dist ) { return dist; }
36
+ };
37
+
38
+ template <typename ElementType>
39
+ struct squareDistance <L2<ElementType>, ElementType>
40
+ {
41
+ typedef typename L2<ElementType>::ResultType ResultType;
42
+ ResultType operator ()( ResultType dist ) { return dist; }
43
+ };
44
+
45
+
46
+ template <typename ElementType>
47
+ struct squareDistance <HellingerDistance<ElementType>, ElementType>
48
+ {
49
+ typedef typename HellingerDistance<ElementType>::ResultType ResultType;
50
+ ResultType operator ()( ResultType dist ) { return dist; }
51
+ };
52
+
53
+
54
+ template <typename ElementType>
55
+ struct squareDistance <ChiSquareDistance<ElementType>, ElementType>
56
+ {
57
+ typedef typename ChiSquareDistance<ElementType>::ResultType ResultType;
58
+ ResultType operator ()( ResultType dist ) { return dist; }
59
+ };
60
+
61
+
62
+ template <typename Distance>
63
+ typename Distance::ResultType ensureSquareDistance ( typename Distance::ResultType dist )
64
+ {
65
+ typedef typename Distance::ElementType ElementType;
66
+
67
+ squareDistance<Distance, ElementType> dummy;
68
+ return dummy ( dist );
69
+ }
70
+
71
+
72
+
16
73
template <typename Distance>
17
74
class CenterChooser
18
75
{
@@ -180,7 +237,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
180
237
// far from previous centers (and this complies to "k-means++: the advantages of careful seeding" article)
181
238
for (int i = 0 ; i < n; i++) {
182
239
closestDistSq[i] = distance_ (points_[indices[i]], points_[indices[index]], cols_);
183
- closestDistSq[i] *= closestDistSq[i];
240
+ closestDistSq[i] = ensureSquareDistance<Distance>( closestDistSq[i] ) ;
184
241
currentPot += closestDistSq[i];
185
242
}
186
243
@@ -208,7 +265,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
208
265
double newPot = 0 ;
209
266
for (int i = 0 ; i < n; i++) {
210
267
DistanceType dist = distance_ (points_[indices[i]], points_[indices[index]], cols_);
211
- newPot += std::min ( dist*dist , closestDistSq[i] );
268
+ newPot += std::min ( ensureSquareDistance<Distance>( dist) , closestDistSq[i] );
212
269
}
213
270
214
271
// Store the best result
@@ -223,7 +280,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
223
280
currentPot = bestNewPot;
224
281
for (int i = 0 ; i < n; i++) {
225
282
DistanceType dist = distance_ (points_[indices[i]], points_[indices[bestNewIndex]], cols_);
226
- closestDistSq[i] = std::min ( dist*dist , closestDistSq[i] );
283
+ closestDistSq[i] = std::min ( ensureSquareDistance<Distance>( dist) , closestDistSq[i] );
227
284
}
228
285
}
229
286
0 commit comments