Skip to content

Commit 0369832

Browse files
committed
continue...
1 parent d3ee813 commit 0369832

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed

source/module_base/blas_connector.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ extern "C"
3939
double dnrm2_( const int *n, const double *X, const int *incX );
4040
double dznrm2_( const int *n, const std::complex<double> *X, const int *incX );
4141

42+
// symmetric rank-k update
43+
void dsyrk_(
44+
const char* uplo,
45+
const char* trans,
46+
const int* n,
47+
const int* k,
48+
const double* alpha,
49+
const double* a,
50+
const int* lda,
51+
const double* beta,
52+
double* c,
53+
const int* ldc
54+
);
55+
4256
// level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work.
4357
void sgemv_(const char*const transa, const int*const m, const int*const n,
4458
const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx,
@@ -267,4 +281,4 @@ void zgemv_i(const char *trans,
267281
*/
268282

269283
#endif // GATHER_INFO
270-
#endif // BLAS_CONNECTOR_H
284+
#endif // BLAS_CONNECTOR_H

source/module_base/grid/batch.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,15 @@ int _maxmin_divide(int m, const double* grid, int* idx) {
7878
double* n = A.data() + 6; // normal vector of the cut plane
7979

8080
// Rearrange the indices to put points in each subset together by
81-
// examining the signed distances of the points to the cut plane (R^T*n).
81+
// examining the signed distances of points to the cut plane (R^T*n).
8282
std::vector<double> dist(m);
8383
dgemv_("T", &i3, &m, &d1, R.data(), &i3, n, &i1, &d0, dist.data(), &i1);
8484

8585
int *head = idx;
8686
std::reverse_iterator<int*> tail(idx + m), rend(idx);
8787
auto is_negative = [&dist](int j) { return dist[j] < 0; };
88-
auto is_nonnegative = [&dist](int j) { return dist[j] >= 0; };
89-
while ( ( head = std::find(head, idx + m, is_negative) ) <
90-
( tail = std::find(tail, rend, is_nonnegative) ).base() ) {
88+
while ( ( head = std::find_if(head, idx + m, is_negative) ) <
89+
( tail = std::find_if_not(tail, rend, is_negative) ).base() ) {
9190
std::swap(*head, *tail);
9291
std::swap(dist[head - idx], dist[tail.base() - idx]);
9392
++head;
@@ -101,29 +100,24 @@ int _maxmin_divide(int m, const double* grid, int* idx) {
101100

102101

103102
std::vector<int> Grid::Batch::maxmin(
104-
int m_max,
105-
int m,
106103
const double* grid,
107-
int* idx
104+
int* idx,
105+
int m,
106+
int m_thr
108107
) {
109-
if (m <= m_max) {
110-
return std::vector<int>{};
108+
if (m <= m_thr) {
109+
return std::vector<int>{0};
111110
}
112111

113112
int m_left = _maxmin_divide(m, grid, idx);
114113

115-
// recursively divide the subsets
116-
std::vector<int> left = maxmin(m_max, m_left, grid, idx);
117-
std::vector<int> right = maxmin(m_max, m - m_left, grid, idx + m_left);
114+
std::vector<int> left = maxmin(grid, idx, m_left, m_thr);
115+
std::vector<int> right = maxmin(grid, idx + m_left, m - m_left, m_thr);
118116
std::for_each(right.begin(), right.end(),
119117
[m_left](int& x) { x += m_left; }
120118
);
121119

122-
// merge all delimiters
123-
int sz_left = left.size();
124-
left.resize(sz_left + right.size() + 1);
125-
left[sz_left] = m_left;
126-
std::copy(right.begin(), right.end(), left.begin() + sz_left + 1);
120+
left.insert(left.end(), right.begin(), right.end());
127121
return left;
128122
}
129123

source/module_base/grid/batch.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,22 @@ namespace Batch {
99
/**
1010
* @brief Divide a set of points into batches by the "MaxMin" algorithm.
1111
*
12-
* This function recursively divides a given set of grid points into two
13-
* subsets by a cut plane using the "MaxMin" algorithm, until the number
14-
* of points in each subset is no more than n_max.
12+
* This function recursively uses a cut plane to divide a set of grid
13+
* points into two subsets using the "MaxMin" algorithm, until the
14+
* number of points in each subset (batch) is no more than m_thr.
15+
*
16+
* @param[in] grid Coordinates of all grid points.
17+
* @param[in,out] idx Indices of the initial set within grid.
18+
* On return, idx will be rearranged such
19+
* that points belonging to the same subset
20+
* are grouped together.
21+
* @param[in] m Number of points in the initial set.
22+
* @param[in] m_thr Size limit of subset.
23+
*
24+
* @return Indices (within idx) of the first point in each batch.
1525
*
1626
*/
17-
std::vector<int> maxmin(int n_max, int n, const double* grid, int* idx);
27+
std::vector<int> maxmin(const double* grid, int* idx, int m, int m_thr);
1828

1929
} // end of namespace Batch
2030
} // end of namespace Grid

0 commit comments

Comments
 (0)