Skip to content

Commit 0776429

Browse files
committed
As some processed distances are already ^2, use template to select whether or not we have to ^2 in KMeanspp
1 parent 3008afa commit 0776429

File tree

1 file changed

+60
-3
lines changed

1 file changed

+60
-3
lines changed

src/cpp/flann/algorithms/center_chooser.h

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,63 @@
1313
namespace flann
1414
{
1515

16+
template <typename Distance, typename ElementType>
17+
struct squareDistance
18+
{
19+
typedef typename Distance::ResultType ResultType;
20+
ResultType operator()( ResultType dist ) { return dist*dist; }
21+
};
22+
23+
24+
template <typename ElementType>
25+
struct squareDistance<L2_Simple<ElementType>, ElementType>
26+
{
27+
typedef typename L2_Simple<ElementType>::ResultType ResultType;
28+
ResultType operator()( ResultType dist ) { return dist; }
29+
};
30+
31+
template <typename ElementType>
32+
struct squareDistance<L2_3D<ElementType>, ElementType>
33+
{
34+
typedef typename L2_3D<ElementType>::ResultType ResultType;
35+
ResultType operator()( ResultType dist ) { return dist; }
36+
};
37+
38+
template <typename ElementType>
39+
struct squareDistance<L2<ElementType>, ElementType>
40+
{
41+
typedef typename L2<ElementType>::ResultType ResultType;
42+
ResultType operator()( ResultType dist ) { return dist; }
43+
};
44+
45+
46+
template <typename ElementType>
47+
struct squareDistance<HellingerDistance<ElementType>, ElementType>
48+
{
49+
typedef typename HellingerDistance<ElementType>::ResultType ResultType;
50+
ResultType operator()( ResultType dist ) { return dist; }
51+
};
52+
53+
54+
template <typename ElementType>
55+
struct squareDistance<ChiSquareDistance<ElementType>, ElementType>
56+
{
57+
typedef typename ChiSquareDistance<ElementType>::ResultType ResultType;
58+
ResultType operator()( ResultType dist ) { return dist; }
59+
};
60+
61+
62+
template <typename Distance>
63+
typename Distance::ResultType ensureSquareDistance( typename Distance::ResultType dist )
64+
{
65+
typedef typename Distance::ElementType ElementType;
66+
67+
squareDistance<Distance, ElementType> dummy;
68+
return dummy( dist );
69+
}
70+
71+
72+
1673
template <typename Distance>
1774
class CenterChooser
1875
{
@@ -180,7 +237,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
180237
// far from previous centers (and this complies to "k-means++: the advantages of careful seeding" article)
181238
for (int i = 0; i < n; i++) {
182239
closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
183-
closestDistSq[i] *= closestDistSq[i];
240+
closestDistSq[i] = ensureSquareDistance<Distance>( closestDistSq[i] );
184241
currentPot += closestDistSq[i];
185242
}
186243

@@ -208,7 +265,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
208265
double newPot = 0;
209266
for (int i = 0; i < n; i++) {
210267
DistanceType dist = distance_(points_[indices[i]], points_[indices[index]], cols_);
211-
newPot += std::min( dist*dist, closestDistSq[i] );
268+
newPot += std::min( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
212269
}
213270

214271
// Store the best result
@@ -223,7 +280,7 @@ class KMeansppCenterChooser : public CenterChooser<Distance>
223280
currentPot = bestNewPot;
224281
for (int i = 0; i < n; i++) {
225282
DistanceType dist = distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_);
226-
closestDistSq[i] = std::min( dist*dist, closestDistSq[i] );
283+
closestDistSq[i] = std::min( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
227284
}
228285
}
229286

0 commit comments

Comments
 (0)