11#include " module_base/grid/batch.h"
22
3- #include " module_base/blas_connector.h"
4- #include " module_base/lapack_connector.h"
53#include < algorithm>
64#include < cassert>
75#include < iterator>
86
7+ #include " module_base/blas_connector.h"
8+ #include " module_base/lapack_connector.h"
9+
910namespace {
1011
1112/* *
12- * @brief Bisect a set of points by the "MaxMin" algorithm.
13+ * @brief Divide a set of points into two subsets by the "MaxMin" algorithm.
1314 *
14- * Given a selected set of grid points specified by the indices `idx`,
15- * bisect this set by a cut plane {x|n^T*(x-c) = 0} where the normal
16- * vector n and the point c are determined by the "MaxMin" problem:
15+ * This function divides a given set of grid points by a cut plane
16+ * {x|n^T*(x-c) = 0} where the normal vector n and the point c are
17+ * determined by the "MaxMin" problem:
1718 *
18- * max min sum_{i=1}^{m} [n^T* (r[idx[i]] - c)]^2
19+ * max min sum_{i=1}^{m} [n^T * (r[idx[i]] - c)]^2
1920 * n c
2021 *
2122 * here r[j] = (grid[3*j], grid[3*j+1], grid[3*j+2]) is the position of
@@ -26,26 +27,30 @@ namespace {
2627 * of the matrix R*R^T, where R is the matrix whose i-th column is
2728 * r[idx[i]] - c.
2829 *
29- * param[in] m Number of the selected points (size of idx).
30- * param[in] grid Coordinates of all grid points.
30+ * @ param[in] m Number of selected points (length of idx).
31+ * @ param[in] grid Coordinates of all grid points.
3132 * grid[3*j], grid[3*j+1], grid[3*j+2] are the
3233 * x, y, z coordinates of the j-th point.
33- * The size of grid is at least 3*m.
34- * param[in,out] idx Indices of the selected points within grid.
34+ * The length of grid is at least 3*m.
35+ * @param[in,out] idx Indices of the selected points within grid.
36+ * On exit, the indices are rearranged such that
37+ * points in each subset are put together.
38+ *
39+ * @return The number of points in the first subset (according to idx).
3540 *
3641 */
37- int _maxmin_bisect (int m, const double * grid, int * idx) {
42+ int _maxmin_divide (int m, const double * grid, int * idx) {
3843 assert (m > 1 );
3944 if (m == 2 ) {
4045 return 1 ;
4146 }
4247
4348 std::vector<double > centroid (3 , 0.0 );
4449 for (int i = 0 ; i < m; ++i) {
45- int ig = idx[i];
46- centroid[0 ] += grid[3 *ig ];
47- centroid[1 ] += grid[3 *ig + 1 ];
48- centroid[2 ] += grid[3 *ig + 2 ];
50+ int j = idx[i];
51+ centroid[0 ] += grid[3 *j ];
52+ centroid[1 ] += grid[3 *j + 1 ];
53+ centroid[2 ] += grid[3 *j + 2 ];
4954 }
5055 centroid[0 ] /= m;
5156 centroid[1 ] /= m;
@@ -60,68 +65,66 @@ int _maxmin_bisect(int m, const double* grid, int* idx) {
6065 R[3 *i + 2 ] = grid[3 *j + 2 ] - centroid[2 ];
6166 }
6267
63- // A = R*R^T is a 3-by-3 matrix
68+ // The normal vector of the cut plane is taken to be the eigenvector
69+ // corresponding to the largest eigenvalue of the 3x3 matrix A = R*R^T.
6470 std::vector<double > A (9 , 0.0 );
6571 int i3 = 3 , i1 = 1 ;
6672 double d0 = 0.0 , d1 = 1.0 ;
6773 dsyrk_ (" U" , " N" , &i3, &m, &d1, R.data (), &i3, &d0, A.data (), &i3);
6874
69- // eigenpairs of A
7075 int info = 0 , lwork = 102 /* determined by a work space query */ ;
7176 std::vector<double > e (3 ), work (lwork);
7277 dsyev_ (" V" , " U" , &i3, A.data (), &i3, e.data (), work.data (), &lwork, &info);
78+ double * n = A.data () + 6 ; // normal vector of the cut plane
7379
74- // normal vector of the cut plane
75- // (eigenvector corresponding to the largest eigenvalue)
76- double * n = A.data () + 6 ;
77-
78- // (signed) distance w.r.t. the cut plane
80+ // Rearrange the indices to put points in each subset together by
81+ // examining the signed distances of the points to the cut plane (R^T*n).
7982 std::vector<double > dist (m);
80- for (int i = 0 ; i < m; ++i) {
81- dist[i] = ddot_ (&i3, R.data () + 3 *i, &i1, n, &i1);
82- }
83+ dgemv_ (" T" , &i3, &m, &d1, R.data (), &i3, n, &i1, &d0, dist.data (), &i1);
8384
84- //
8585 int *head = idx;
8686 std::reverse_iterator<int *> tail (idx + m), rend (idx);
8787 auto is_negative = [&dist](int j) { return dist[j] < 0 ; };
8888 auto is_nonnegative = [&dist](int j) { return dist[j] >= 0 ; };
89- while ( (head = std::find (head, idx + m, is_negative)) <
90- (tail = std::find (tail, rend, is_nonnegative)).base ()) {
89+ while ( ( head = std::find (head, idx + m, is_negative) ) <
90+ ( tail = std::find (tail, rend, is_nonnegative) ).base () ) {
9191 std::swap (*head, *tail);
9292 std::swap (dist[head - idx], dist[tail.base () - idx]);
9393 ++head;
9494 ++tail;
9595 }
9696
97- return std::find (idx, idx + m, is_nonnegative ) - idx;
97+ return std::find (idx, idx + m, is_negative ) - idx;
9898}
9999
100100} // end of anonymous namespace
101101
102+
102103std::vector<int > Grid::Batch::maxmin (
103- int n_max ,
104- int n ,
104+ int m_max ,
105+ int m ,
105106 const double * grid,
106107 int * idx
107108) {
108- if (n <= n_max ) {
109+ if (m <= m_max ) {
109110 return std::vector<int >{};
110111 }
111112
112- int n_left = _maxmin_bisect (n , grid, idx);
113+ int m_left = _maxmin_divide (m , grid, idx);
113114
114- std::vector<int > delim_left = maxmin (n_max, n_left, grid, idx);
115- std::vector<int > delim_right = maxmin (n_max, n - n_left, grid + n_left, idx + n_left);
116- std::for_each (delim_right.begin (), delim_right.end (),
117- [n_left](int & x) { x += n_left; }
115+ // recursively divide the subsets
116+ std::vector<int > left = maxmin (m_max, m_left, grid, idx);
117+ std::vector<int > right = maxmin (m_max, m - m_left, grid, idx + m_left);
118+ std::for_each (right.begin (), right.end (),
119+ [m_left](int & x) { x += m_left; }
118120 );
119121
120- // merge all delimiters into delim_left
121- delim_left.resize (delim_left.size () + delim_right.size () + 1 );
122- delim_left[delim_left.size ()] = n_left;
123- std::copy (delim_right.begin (), delim_right.end (), delim_left.begin () + delim_left.size () + 1 );
124- return delim_left;
122+ // merge all delimiters
123+ int sz_left = left.size ();
124+ left.resize (sz_left + right.size () + 1 );
125+ left[sz_left] = m_left;
126+ std::copy (right.begin (), right.end (), left.begin () + sz_left + 1 );
127+ return left;
125128}
126129
127130
0 commit comments