Skip to content

Commit 4c8876e

Browse files
committed
more test
1 parent f8b00a2 commit 4c8876e

File tree

1 file changed

+93
-32
lines changed

1 file changed

+93
-32
lines changed

source/module_base/grid/test/test_batch.cpp

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,64 +10,112 @@
1010

1111
using namespace Grid::Batch;
1212

13+
1314
class BatchTest: public ::testing::Test
1415
{
1516
protected:
16-
void SetUp();
1717

1818
std::vector<double> grid_;
1919
std::vector<int> idx_;
2020

21-
// parameters for cluster generation
22-
int n_each_ = 10;
23-
double width_ = 1.0;
24-
25-
// These offsets should be different from each other as maxmin might
26-
// fail for highly symmetric, well-separated clusters.
21+
// parameters for octant clusters
22+
const int n_batch_oct_ = 10;
23+
const double width_oct_ = 1.0;
24+
const double offset_x_ = 7.0;
25+
const double offset_y_ = 8.0;
26+
const double offset_z_ = 9.0;
27+
// NOTE: These offsets should be different from each other as maxmin
28+
// might fail for highly symmetric, well-separated clusters.
2729
// Consider the case where the 8 clusters as a whole have octahedral
2830
// symmetry. In this case, R*R^T must be proprotional to the identity,
2931
// and eigenvalues are three-fold degenerate, because xy, yz and zx
3032
// plane are equivalent in terms of the maxmin optimization problem.
3133
// This means eigenvectors are arbitrary in this case.
32-
double offset_x_ = 7.0;
33-
double offset_y_ = 8.0;
34-
double offset_z_ = 9.0;
34+
35+
// parameters for random cluster
36+
const int n_grid_rand_ = 10000;
37+
const int n_batch_rand_ = 100;
38+
const double width_rand_ = 20.0;
39+
const double xc_ = 5.0;
40+
const double yc_ = 5.0;
41+
const double zc_ = 7.0;
3542
};
3643

37-
std::vector<double> gen_octant_cluster(int n_each, double offset_x, double offset_y, double offset_z, double width) {
44+
45+
void gen_random(
46+
int ngrid,
47+
double xc,
48+
double yc,
49+
double zc,
50+
double width,
51+
std::vector<double>& grid,
52+
std::vector<int>& idx
53+
) {
54+
55+
// Generates a set of points centered around (xc, yc, zc).
56+
57+
std::random_device rd;
58+
std::mt19937 gen(rd());
59+
std::uniform_real_distribution<double> dis(-width, width);
60+
61+
grid.resize(3 * ngrid);
62+
for (int i = 0; i < ngrid; ++i) {
63+
grid[3*i ] = xc + dis(gen);
64+
grid[3*i + 1] = yc + dis(gen);
65+
grid[3*i + 2] = zc + dis(gen);
66+
}
67+
68+
idx.resize(ngrid);
69+
std::iota(idx.begin(), idx.end(), 0);
70+
std::shuffle(idx.begin(), idx.end(), gen);
71+
}
72+
73+
74+
void gen_octant(
75+
int n_each,
76+
double offset_x,
77+
double offset_y,
78+
double offset_z,
79+
double width,
80+
std::vector<double>& grid,
81+
std::vector<int>& idx
82+
) {
3883

3984
// Generates a set of points consisting of 8 well-separated, equal-sized
4085
// clusters located in individual octants.
4186

42-
std::vector<double> grid(n_each * 8 * 3);
43-
int I = 0;
44-
4587
std::random_device rd;
4688
std::mt19937 gen(rd());
4789
std::uniform_real_distribution<double> dis(-width, width);
4890

91+
int ngrid = 8 * n_each;
92+
grid.resize(3 * ngrid);
93+
int I = 0;
4994
for (int sign_x : {-1, 1}) {
5095
for (int sign_y : {-1, 1}) {
5196
for (int sign_z : {-1, 1}) {
52-
for (int i = 0; i < n_each; ++i) {
97+
for (int i = 0; i < n_each; ++i, ++I) {
5398
grid[3*I ] = sign_x * offset_x + dis(gen);
5499
grid[3*I + 1] = sign_y * offset_y + dis(gen);
55100
grid[3*I + 2] = sign_z * offset_z + dis(gen);
56-
++I;
57101
}
58102
}
59103
}
60104
}
61-
return grid;
105+
106+
idx.resize(ngrid);
107+
std::iota(idx.begin(), idx.end(), 0);
108+
std::shuffle(idx.begin(), idx.end(), gen);
62109
}
63110

111+
64112
bool is_same_octant(int ngrid, const double* grid) {
65113
if (ngrid == 0) {
66114
return true;
67115
}
68-
bool is_positive_x = grid[0] > 0;
69-
bool is_positive_y = grid[1] > 0;
70-
bool is_positive_z = grid[2] > 0;
116+
const bool is_positive_x = grid[0] > 0;
117+
const bool is_positive_y = grid[1] > 0;
118+
const bool is_positive_z = grid[2] > 0;
71119
const double* end = grid + 3 * ngrid;
72120
for (; grid != end; grid += 3) {
73121
if ( is_positive_x != (grid[0] > 0) ||
@@ -80,16 +128,26 @@ bool is_same_octant(int ngrid, const double* grid) {
80128
}
81129

82130

83-
void BatchTest::SetUp()
131+
TEST_F(BatchTest, MaxMinRandomCluster)
84132
{
85-
grid_ = gen_octant_cluster(n_each_, offset_x_, offset_y_, offset_z_, width_);
133+
// This test verifies that the sizes of batches produced by maxmin
134+
// do not exceed the specified limit.
86135

87-
idx_.resize(grid_.size() / 3);
88-
std::iota(idx_.begin(), idx_.end(), 0);
136+
gen_random(n_grid_rand_, xc_, yc_, zc_, width_rand_, grid_, idx_);
89137

90-
std::random_device rd;
91-
std::mt19937 g(rd());
92-
std::shuffle(idx_.begin(), idx_.end(), g);
138+
std::vector<int> delim =
139+
maxmin(grid_.data(), idx_.data(), idx_.size(), n_batch_rand_);
140+
141+
for (size_t i = 0; i < delim.size(); ++i) {
142+
if (i == 0) {
143+
EXPECT_EQ(delim[i], 0);
144+
} else {
145+
int sz_batch = delim[i] - delim[i-1];
146+
EXPECT_GT(sz_batch, 0);
147+
EXPECT_LE(sz_batch, n_batch_rand_);
148+
}
149+
}
150+
EXPECT_LE(idx_.size() - delim.back(), n_batch_rand_);
93151
}
94152

95153

@@ -99,26 +157,29 @@ TEST_F(BatchTest, MaxMinOctantCluster)
99157
// well-separated, equal-sized clusters located in individual octants.
100158
// The resulting batches should be able to recover this structure.
101159

160+
gen_octant(n_batch_oct_, offset_x_, offset_y_, offset_z_, width_oct_,
161+
grid_, idx_);
162+
102163
std::vector<int> delim =
103-
maxmin(grid_.data(), idx_.data(), grid_.size() / 3, n_each_);
164+
maxmin(grid_.data(), idx_.data(), idx_.size(), n_batch_oct_);
104165

105166
EXPECT_EQ(delim.size(), 8);
106167

107-
std::vector<double> grid_batch(3 * n_each_);
168+
std::vector<double> grid_batch(3 * n_batch_oct_);
108169
for (int i = 0; i < 8; ++i) {
109170

110-
EXPECT_EQ(delim[i], i * n_each_);
171+
EXPECT_EQ(delim[i], i * n_batch_oct_);
111172

112173
// collect points within the present batch
113-
for (int j = 0; j < n_each_; ++j) {
174+
for (int j = 0; j < n_batch_oct_; ++j) {
114175
int ig = idx_[delim[i] + j];
115176
grid_batch[3*j ] = grid_[3*ig ];
116177
grid_batch[3*j + 1] = grid_[3*ig + 1];
117178
grid_batch[3*j + 2] = grid_[3*ig + 2];
118179
}
119180

120181
// verify that points in a batch reside in the same octant
121-
EXPECT_TRUE(is_same_octant(n_each_, grid_batch.data()));
182+
EXPECT_TRUE(is_same_octant(n_batch_oct_, grid_batch.data()));
122183
}
123184
}
124185

0 commit comments

Comments
 (0)