1010
1111using namespace Grid ::Batch;
1212
13+
1314class BatchTest : public ::testing::Test
1415{
1516protected:
16- void SetUp ();
1717
1818 std::vector<double > grid_;
1919 std::vector<int > idx_;
2020
21- // parameters for cluster generation
22- int n_each_ = 10 ;
23- double width_ = 1.0 ;
24-
25- // These offsets should be different from each other as maxmin might
26- // fail for highly symmetric, well-separated clusters.
21+ // parameters for octant clusters
22+ const int n_batch_oct_ = 10 ;
23+ const double width_oct_ = 1.0 ;
24+ const double offset_x_ = 7.0 ;
25+ const double offset_y_ = 8.0 ;
26+ const double offset_z_ = 9.0 ;
27+ // NOTE: These offsets should be different from each other as maxmin
28+ // might fail for highly symmetric, well-separated clusters.
2729 // Consider the case where the 8 clusters as a whole have octahedral
2830 // symmetry. In this case, R*R^T must be proprotional to the identity,
2931 // and eigenvalues are three-fold degenerate, because xy, yz and zx
3032 // plane are equivalent in terms of the maxmin optimization problem.
3133 // This means eigenvectors are arbitrary in this case.
32- double offset_x_ = 7.0 ;
33- double offset_y_ = 8.0 ;
34- double offset_z_ = 9.0 ;
34+
35+ // parameters for random cluster
36+ const int n_grid_rand_ = 10000 ;
37+ const int n_batch_rand_ = 100 ;
38+ const double width_rand_ = 20.0 ;
39+ const double xc_ = 5.0 ;
40+ const double yc_ = 5.0 ;
41+ const double zc_ = 7.0 ;
3542};
3643
37- std::vector<double > gen_octant_cluster (int n_each, double offset_x, double offset_y, double offset_z, double width) {
44+
45+ void gen_random (
46+ int ngrid,
47+ double xc,
48+ double yc,
49+ double zc,
50+ double width,
51+ std::vector<double >& grid,
52+ std::vector<int >& idx
53+ ) {
54+
55+ // Generates a set of points centered around (xc, yc, zc).
56+
57+ std::random_device rd;
58+ std::mt19937 gen (rd ());
59+ std::uniform_real_distribution<double > dis (-width, width);
60+
61+ grid.resize (3 * ngrid);
62+ for (int i = 0 ; i < ngrid; ++i) {
63+ grid[3 *i ] = xc + dis (gen);
64+ grid[3 *i + 1 ] = yc + dis (gen);
65+ grid[3 *i + 2 ] = zc + dis (gen);
66+ }
67+
68+ idx.resize (ngrid);
69+ std::iota (idx.begin (), idx.end (), 0 );
70+ std::shuffle (idx.begin (), idx.end (), gen);
71+ }
72+
73+
74+ void gen_octant (
75+ int n_each,
76+ double offset_x,
77+ double offset_y,
78+ double offset_z,
79+ double width,
80+ std::vector<double >& grid,
81+ std::vector<int >& idx
82+ ) {
3883
3984 // Generates a set of points consisting of 8 well-separated, equal-sized
4085 // clusters located in individual octants.
4186
42- std::vector<double > grid (n_each * 8 * 3 );
43- int I = 0 ;
44-
4587 std::random_device rd;
4688 std::mt19937 gen (rd ());
4789 std::uniform_real_distribution<double > dis (-width, width);
4890
91+ int ngrid = 8 * n_each;
92+ grid.resize (3 * ngrid);
93+ int I = 0 ;
4994 for (int sign_x : {-1 , 1 }) {
5095 for (int sign_y : {-1 , 1 }) {
5196 for (int sign_z : {-1 , 1 }) {
52- for (int i = 0 ; i < n_each; ++i) {
97+ for (int i = 0 ; i < n_each; ++i, ++I ) {
5398 grid[3 *I ] = sign_x * offset_x + dis (gen);
5499 grid[3 *I + 1 ] = sign_y * offset_y + dis (gen);
55100 grid[3 *I + 2 ] = sign_z * offset_z + dis (gen);
56- ++I;
57101 }
58102 }
59103 }
60104 }
61- return grid;
105+
106+ idx.resize (ngrid);
107+ std::iota (idx.begin (), idx.end (), 0 );
108+ std::shuffle (idx.begin (), idx.end (), gen);
62109}
63110
111+
64112bool is_same_octant (int ngrid, const double * grid) {
65113 if (ngrid == 0 ) {
66114 return true ;
67115 }
68- bool is_positive_x = grid[0 ] > 0 ;
69- bool is_positive_y = grid[1 ] > 0 ;
70- bool is_positive_z = grid[2 ] > 0 ;
116+ const bool is_positive_x = grid[0 ] > 0 ;
117+ const bool is_positive_y = grid[1 ] > 0 ;
118+ const bool is_positive_z = grid[2 ] > 0 ;
71119 const double * end = grid + 3 * ngrid;
72120 for (; grid != end; grid += 3 ) {
73121 if ( is_positive_x != (grid[0 ] > 0 ) ||
@@ -80,16 +128,26 @@ bool is_same_octant(int ngrid, const double* grid) {
80128}
81129
82130
83- void BatchTest::SetUp ( )
131+ TEST_F (BatchTest, MaxMinRandomCluster )
84132{
85- grid_ = gen_octant_cluster (n_each_, offset_x_, offset_y_, offset_z_, width_);
133+ // This test verifies that the sizes of batches produced by maxmin
134+ // do not exceed the specified limit.
86135
87- idx_.resize (grid_.size () / 3 );
88- std::iota (idx_.begin (), idx_.end (), 0 );
136+ gen_random (n_grid_rand_, xc_, yc_, zc_, width_rand_, grid_, idx_);
89137
90- std::random_device rd;
91- std::mt19937 g (rd ());
92- std::shuffle (idx_.begin (), idx_.end (), g);
138+ std::vector<int > delim =
139+ maxmin (grid_.data (), idx_.data (), idx_.size (), n_batch_rand_);
140+
141+ for (size_t i = 0 ; i < delim.size (); ++i) {
142+ if (i == 0 ) {
143+ EXPECT_EQ (delim[i], 0 );
144+ } else {
145+ int sz_batch = delim[i] - delim[i-1 ];
146+ EXPECT_GT (sz_batch, 0 );
147+ EXPECT_LE (sz_batch, n_batch_rand_);
148+ }
149+ }
150+ EXPECT_LE (idx_.size () - delim.back (), n_batch_rand_);
93151}
94152
95153
@@ -99,26 +157,29 @@ TEST_F(BatchTest, MaxMinOctantCluster)
99157 // well-separated, equal-sized clusters located in individual octants.
100158 // The resulting batches should be able to recover this structure.
101159
160+ gen_octant (n_batch_oct_, offset_x_, offset_y_, offset_z_, width_oct_,
161+ grid_, idx_);
162+
102163 std::vector<int > delim =
103- maxmin (grid_.data (), idx_.data (), grid_ .size () / 3 , n_each_ );
164+ maxmin (grid_.data (), idx_.data (), idx_ .size (), n_batch_oct_ );
104165
105166 EXPECT_EQ (delim.size (), 8 );
106167
107- std::vector<double > grid_batch (3 * n_each_ );
168+ std::vector<double > grid_batch (3 * n_batch_oct_ );
108169 for (int i = 0 ; i < 8 ; ++i) {
109170
110- EXPECT_EQ (delim[i], i * n_each_ );
171+ EXPECT_EQ (delim[i], i * n_batch_oct_ );
111172
112173 // collect points within the present batch
113- for (int j = 0 ; j < n_each_ ; ++j) {
174+ for (int j = 0 ; j < n_batch_oct_ ; ++j) {
114175 int ig = idx_[delim[i] + j];
115176 grid_batch[3 *j ] = grid_[3 *ig ];
116177 grid_batch[3 *j + 1 ] = grid_[3 *ig + 1 ];
117178 grid_batch[3 *j + 2 ] = grid_[3 *ig + 2 ];
118179 }
119180
120181 // verify that points in a batch reside in the same octant
121- EXPECT_TRUE (is_same_octant (n_each_ , grid_batch.data ()));
182+ EXPECT_TRUE (is_same_octant (n_batch_oct_ , grid_batch.data ()));
122183 }
123184}
124185
0 commit comments