Skip to content

Commit 16a26a3

Browse files
committed
Add dimension parameter for BPCG method
1 parent 706633f commit 16a26a3

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ DiagoBPCG<T, Device>::~DiagoBPCG() {
3030
}
3131

3232
template<typename T, typename Device>
33-
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
33+
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis, const int ndim) {
3434
// Specify the problem size n_basis, n_band, while lda is n_basis
3535
this->n_band = nband;
3636
this->n_basis = nbasis;
37-
37+
this->n_dim = ndim;
3838

3939
// All column major tensors
4040

source/module_hsolver/diago_bpcg.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ class DiagoBPCG
5252
*
5353
* @param nband The number of bands.
5454
* @param nbasis The number of basis functions. Leading dimension of psi.
55+
* @param ndim The number of valid dimension of psi.
5556
*/
56-
void init_iter(const int nband, const int nbasis);
57+
void init_iter(const int nband, const int nbasis, const int ndim);
5758

5859
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
5960

@@ -77,6 +78,8 @@ class DiagoBPCG
7778
int n_band = 0;
7879
/// the number of cols of the input psi
7980
int n_basis = 0;
81+
/// valid dimension of psi
82+
int n_dim = 0;
8083
/// max iter steps for all-band cg loop
8184
int nline = 4;
8285

@@ -332,6 +335,7 @@ class DiagoBPCG
332335

333336
using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
334337
using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
338+
using gemm_op = hsolver::gemm_op<T, Device>;
335339

336340
};
337341

0 commit comments

Comments
 (0)