Skip to content

Commit d3ee813

Browse files
committed
crude test...
1 parent 1a2ba2d commit d3ee813

File tree

3 files changed

+181
-43
lines changed

3 files changed

+181
-43
lines changed

source/module_base/grid/batch.cpp

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
#include "module_base/grid/batch.h"
22

3-
#include "module_base/blas_connector.h"
4-
#include "module_base/lapack_connector.h"
53
#include <algorithm>
64
#include <cassert>
75
#include <iterator>
86

7+
#include "module_base/blas_connector.h"
8+
#include "module_base/lapack_connector.h"
9+
910
namespace {
1011

1112
/**
12-
* @brief Bisect a set of points by the "MaxMin" algorithm.
13+
* @brief Divide a set of points into two subsets by the "MaxMin" algorithm.
1314
*
14-
* Given a selected set of grid points specified by the indices `idx`,
15-
* bisect this set by a cut plane {x|n^T*(x-c) = 0} where the normal
16-
* vector n and the point c are determined by the "MaxMin" problem:
15+
* This function divides a given set of grid points by a cut plane
16+
* {x|n^T*(x-c) = 0} where the normal vector n and the point c are
17+
* determined by the "MaxMin" problem:
1718
*
18-
* max min sum_{i=1}^{m} [n^T*(r[idx[i]] - c)]^2
19+
* max min sum_{i=1}^{m} [n^T * (r[idx[i]] - c)]^2
1920
* n c
2021
*
2122
* here r[j] = (grid[3*j], grid[3*j+1], grid[3*j+2]) is the position of
@@ -26,26 +27,30 @@ namespace {
2627
* of the matrix R*R^T, where R is the matrix whose i-th column is
2728
* r[idx[i]] - c.
2829
*
29-
* param[in] m Number of the selected points (size of idx).
30-
* param[in] grid Coordinates of all grid points.
30+
* @param[in] m Number of selected points (length of idx).
31+
* @param[in] grid Coordinates of all grid points.
3132
* grid[3*j], grid[3*j+1], grid[3*j+2] are the
3233
* x, y, z coordinates of the j-th point.
33-
* The size of grid is at least 3*m.
34-
* param[in,out] idx Indices of the selected points within grid.
34+
 *                      The length of grid must be at least 3*(max(idx)+1),
 *                      because grid is addressed through the values in idx.
35+
* @param[in,out] idx Indices of the selected points within grid.
36+
* On exit, the indices are rearranged such that
37+
* points in each subset are put together.
38+
*
39+
* @return The number of points in the first subset (according to idx).
3540
*
3641
*/
37-
int _maxmin_bisect(int m, const double* grid, int* idx) {
42+
int _maxmin_divide(int m, const double* grid, int* idx) {
3843
assert(m > 1);
3944
if (m == 2) {
4045
return 1;
4146
}
4247

4348
std::vector<double> centroid(3, 0.0);
4449
for (int i = 0; i < m; ++i) {
45-
int ig = idx[i];
46-
centroid[0] += grid[3*ig ];
47-
centroid[1] += grid[3*ig + 1];
48-
centroid[2] += grid[3*ig + 2];
50+
int j = idx[i];
51+
centroid[0] += grid[3*j ];
52+
centroid[1] += grid[3*j + 1];
53+
centroid[2] += grid[3*j + 2];
4954
}
5055
centroid[0] /= m;
5156
centroid[1] /= m;
@@ -60,68 +65,66 @@ int _maxmin_bisect(int m, const double* grid, int* idx) {
6065
R[3*i + 2] = grid[3*j + 2] - centroid[2];
6166
}
6267

63-
// A = R*R^T is a 3-by-3 matrix
68+
// The normal vector of the cut plane is taken to be the eigenvector
69+
// corresponding to the largest eigenvalue of the 3x3 matrix A = R*R^T.
6470
std::vector<double> A(9, 0.0);
6571
int i3 = 3, i1 = 1;
6672
double d0 = 0.0, d1 = 1.0;
6773
dsyrk_("U", "N", &i3, &m, &d1, R.data(), &i3, &d0, A.data(), &i3);
6874

69-
// eigenpairs of A
7075
int info = 0, lwork = 102 /* determined by a work space query */;
7176
std::vector<double> e(3), work(lwork);
7277
dsyev_("V", "U", &i3, A.data(), &i3, e.data(), work.data(), &lwork, &info);
78+
double* n = A.data() + 6; // normal vector of the cut plane
7379

74-
// normal vector of the cut plane
75-
// (eigenvector corresponding to the largest eigenvalue)
76-
double* n = A.data() + 6;
77-
78-
// (signed) distance w.r.t. the cut plane
80+
// Rearrange the indices to put points in each subset together by
81+
// examining the signed distances of the points to the cut plane (R^T*n).
7982
std::vector<double> dist(m);
80-
for (int i = 0; i < m; ++i) {
81-
dist[i] = ddot_(&i3, R.data() + 3*i, &i1, n, &i1);
82-
}
83+
dgemv_("T", &i3, &m, &d1, R.data(), &i3, n, &i1, &d0, dist.data(), &i1);
8384

84-
//
8585
int *head = idx;
8686
std::reverse_iterator<int*> tail(idx + m), rend(idx);
8787
auto is_negative = [&dist](int j) { return dist[j] < 0; };
8888
auto is_nonnegative = [&dist](int j) { return dist[j] >= 0; };
89-
while ( (head = std::find(head, idx + m, is_negative)) <
90-
(tail = std::find(tail, rend, is_nonnegative)).base()) {
89+
while ( ( head = std::find(head, idx + m, is_negative) ) <
90+
( tail = std::find(tail, rend, is_nonnegative) ).base() ) {
9191
std::swap(*head, *tail);
9292
std::swap(dist[head - idx], dist[tail.base() - idx]);
9393
++head;
9494
++tail;
9595
}
9696

97-
return std::find(idx, idx + m, is_nonnegative) - idx;
97+
return std::find(idx, idx + m, is_negative) - idx;
9898
}
9999

100100
} // end of anonymous namespace
101101

102+
102103
std::vector<int> Grid::Batch::maxmin(
103-
int n_max,
104-
int n,
104+
int m_max,
105+
int m,
105106
const double* grid,
106107
int* idx
107108
) {
108-
if (n <= n_max) {
109+
if (m <= m_max) {
109110
return std::vector<int>{};
110111
}
111112

112-
int n_left = _maxmin_bisect(n, grid, idx);
113+
int m_left = _maxmin_divide(m, grid, idx);
113114

114-
std::vector<int> delim_left = maxmin(n_max, n_left, grid, idx);
115-
std::vector<int> delim_right = maxmin(n_max, n - n_left, grid + n_left, idx + n_left);
116-
std::for_each(delim_right.begin(), delim_right.end(),
117-
[n_left](int& x) { x += n_left; }
115+
// recursively divide the subsets
116+
std::vector<int> left = maxmin(m_max, m_left, grid, idx);
117+
std::vector<int> right = maxmin(m_max, m - m_left, grid, idx + m_left);
118+
std::for_each(right.begin(), right.end(),
119+
[m_left](int& x) { x += m_left; }
118120
);
119121

120-
// merge all delimiters into delim_left
121-
delim_left.resize(delim_left.size() + delim_right.size() + 1);
122-
delim_left[delim_left.size()] = n_left;
123-
std::copy(delim_right.begin(), delim_right.end(), delim_left.begin() + delim_left.size() + 1);
124-
return delim_left;
122+
// merge all delimiters
123+
int sz_left = left.size();
124+
left.resize(sz_left + right.size() + 1);
125+
left[sz_left] = m_left;
126+
std::copy(right.begin(), right.end(), left.begin() + sz_left + 1);
127+
return left;
125128
}
126129

127130

source/module_base/grid/batch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
namespace Grid {
77
namespace Batch {
88

9+
/**
10+
* @brief Divide a set of points into batches by the "MaxMin" algorithm.
11+
*
12+
* This function recursively divides a given set of grid points into two
13+
* subsets by a cut plane using the "MaxMin" algorithm, until the number
14+
* of points in each subset is no more than n_max.
15+
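 *
 * @param[in]       n_max   Maximum number of points allowed in a batch.
 * @param[in]       n       Number of selected points (length of idx).
 * @param[in]       grid    Coordinates of all grid points;
 *                          grid[3*j], grid[3*j+1], grid[3*j+2] are the
 *                          x, y, z coordinates of the j-th point.
 * @param[in,out]   idx     Indices of the selected points within grid.
 *                          On exit, the indices are rearranged such that
 *                          points of each batch are stored contiguously.
 *
 * @return Ascending positions within idx at which the second and all
 *         subsequent batches begin (empty if no division was needed).
 *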
16+
*/
917
std::vector<int> maxmin(int n_max, int n, const double* grid, int* idx);
1018

1119
} // end of namespace Batch
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include "module_base/grid/batch.h"

#include "gtest/gtest.h"
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>
6+
7+
#ifdef __MPI
8+
#include <mpi.h>
9+
#endif
10+
11+
using namespace Grid::Batch;
12+
13+
14+
class BatchTest: public ::testing::Test
15+
{
16+
protected:
17+
void SetUp();
18+
19+
std::vector<double> grid_;
20+
std::vector<int> idx_;
21+
22+
int n_each_ = 10;
23+
double offset_ = 10.0;
24+
double width_ = 1.0;
25+
};
26+
27+
/**
 * @brief Generates 8 well-separated, equal-sized clusters of points,
 *        one in each octant.
 *
 * Each cluster is centred at (+/-offset, +/-offset, +/-offset) and every
 * coordinate is perturbed by a random shift drawn uniformly from
 * [-width, width). Clusters are stored one after another, with the sign of
 * x varying slowest and the sign of z varying fastest (-1 before +1).
 *
 * @param[in]   n_each  Number of points in each cluster.
 * @param[in]   offset  Distance of a cluster centre from each coordinate plane.
 * @param[in]   width   Half-width of the random perturbation.
 *
 * @return Coordinates of the 8*n_each points (3*8*n_each values); element
 *         3*j, 3*j+1, 3*j+2 are the x, y, z coordinates of the j-th point.
 */
std::vector<double> gen_octant_cluster(int n_each, double offset, double width) {
    constexpr int n_octants = 8;

    // three coordinates are stored for every point
    std::vector<double> grid(3 * n_octants * n_each);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<double> dis(-width, width);

    int I = 0; // running index of the point being generated
    for (int sign_x : {-1, 1}) {
        for (int sign_y : {-1, 1}) {
            for (int sign_z : {-1, 1}) {
                for (int i = 0; i < n_each; ++i) {
                    grid[3*I    ] = sign_x * offset + dis(gen);
                    grid[3*I + 1] = sign_y * offset + dis(gen);
                    grid[3*I + 2] = sign_z * offset + dis(gen);
                    ++I;
                }
            }
        }
    }

    return grid;
}
54+
55+
/**
 * @brief Checks whether all given points lie in the same octant.
 *
 * The octant of a point is identified by whether each of its coordinates
 * is strictly positive.
 *
 * @param[in]   ngrid   Number of points.
 * @param[in]   grid    Coordinates of the points; grid[3*j], grid[3*j+1],
 *                      grid[3*j+2] are x, y, z of the j-th point.
 *
 * @return true if ngrid is 0 or every point has the same sign pattern as
 *         the first one; false otherwise.
 */
bool is_same_octant(int ngrid, const double* grid) {
    // an empty set trivially lies within a single octant
    if (ngrid == 0) {
        return true;
    }

    // the sign pattern of the first point serves as the reference
    const bool ref_x = grid[0] > 0;
    const bool ref_y = grid[1] > 0;
    const bool ref_z = grid[2] > 0;

    for (int p = 0; p < ngrid; ++p) {
        const double* r = grid + 3 * p;
        const bool matches = (r[0] > 0) == ref_x
                          && (r[1] > 0) == ref_y
                          && (r[2] > 0) == ref_z;
        if (!matches) {
            return false;
        }
    }
    return true;
}
72+
73+
74+
void BatchTest::SetUp()
75+
{
76+
grid_ = gen_octant_cluster(n_each_, offset_, width_);
77+
78+
idx_.resize(grid_.size());
79+
std::iota(idx_.begin(), idx_.end(), 0);
80+
81+
std::random_device rd;
82+
std::mt19937 g(rd());
83+
std::shuffle(idx_.begin(), idx_.end(), g);
84+
}
85+
86+
87+
TEST_F(BatchTest, MaxMinOctantCluster)
{
    // This test applies maxmin to a set of points consisting of 8
    // well-separated, equal-sized clusters located in individual octants.
    // The resulting batches should be able to recover this structure.
    // NOTE(review): this presumes no cut plane passes through a cluster;
    // the point set is random, so confirm this holds for the chosen
    // offset/width.

    const int n_batches = 8;
    const int n_points = n_batches * n_each_;

    // The fixture must provide 3 coordinates per point and one index per
    // point; stop here instead of reading out of bounds otherwise.
    ASSERT_EQ(static_cast<int>(grid_.size()), 3 * n_points);
    ASSERT_EQ(static_cast<int>(idx_.size()), n_points);

    // maxmin takes the number of points, not the number of coordinates
    std::vector<int> delim =
        maxmin(n_each_, n_points, grid_.data(), idx_.data());

    // 8 batches are separated by 7 delimiters
    ASSERT_EQ(delim.size(), 7u);
    for (int i = 0; i < n_batches - 1; ++i) {
        // check number of points in each batch via index delimiters
        EXPECT_EQ(delim[i], (i+1) * n_each_);
    }

    // verify that the points in each of the 8 batches are in the same octant
    for (int i = 0; i < n_batches; ++i) {
        std::vector<double> batch(3 * n_each_);
        for (int j = 0; j < n_each_; ++j) {
            // batches are defined by the rearranged indices, not by the
            // storage order of grid_
            const int ig = idx_[i*n_each_ + j];
            for (int k = 0; k < 3; ++k) {
                batch[3*j + k] = grid_[3*ig + k];
            }
        }
        EXPECT_TRUE(is_same_octant(n_each_, batch.data()));
    }
}
111+
112+
113+
int main(int argc, char** argv)
114+
{
115+
#ifdef __MPI
116+
MPI_Init(&argc, &argv);
117+
#endif
118+
119+
testing::InitGoogleTest(&argc, argv);
120+
int result = RUN_ALL_TESTS();
121+
122+
#ifdef __MPI
123+
MPI_Finalize();
124+
#endif
125+
126+
return result;
127+
}

0 commit comments

Comments
 (0)