Skip to content

Commit c898e52

Browse files
authored
Fix: use gemm instead of einsum in BPCG (#5827)
* Add dimension parameter for BPCG method * Add utils for hsovler gemm_op * Change code to fit new bpcg init interface * using gemm instead of einsum in orth_cholesky * using gemm instead of einsum in orth_projection * replace einsum by gemm in orth_projection * replace einsum by gemm in rotate_wf * replace einsum by gemm in diag_hsub * Update 102_PW_BPCG totalstressref reference value
1 parent 24abddd commit c898e52

File tree

5 files changed

+110
-10
lines changed

5 files changed

+110
-10
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
2222
this->device_type = ct::DeviceTypeToEnum<Device>::value;
2323

2424
this->h_prec = std::move(ct::TensorMap((void *) precondition_in, r_type, device_type, {this->n_basis}));
25+
26+
this->one = &one_;
27+
this->zero = &zero_;
28+
this->neg_one = &neg_one_;
2529
}
2630

2731
template<typename T, typename Device>
@@ -30,11 +34,11 @@ DiagoBPCG<T, Device>::~DiagoBPCG() {
3034
}
3135

3236
template<typename T, typename Device>
33-
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
37+
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis, const int ndim) {
3438
// Specify the problem size n_basis, n_band, while lda is n_basis
3539
this->n_band = nband;
3640
this->n_basis = nbasis;
37-
41+
this->n_dim = ndim;
3842

3943
// All column major tensors
4044

@@ -93,7 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9397
// hsub_out = psi_out * transc(psi_out)
9498
ct::EinsumOption option(
9599
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
96-
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
100+
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
101+
102+
// gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
103+
gemm_op()(this->ctx,
104+
'C',
105+
'N',
106+
this->n_band, //m
107+
this->n_band, //n
108+
this->n_dim, //k
109+
this->one, //1.0
110+
psi_out.data<T>(),
111+
this->n_basis, //lda
112+
psi_out.data<T>(),
113+
this->n_basis, //ldb
114+
this->zero, //0.0
115+
hsub_out.data<T>(),
116+
this->n_band); //ldc
97117

98118
// set hsub matrix to lower format;
99119
ct::kernels::set_matrix<T, ct_Device>()(
@@ -145,12 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
145165
{
146166
ct::EinsumOption option(
147167
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
148-
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
168+
// hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169+
170+
// this->orth_projection(this->psi, this->hsub, this->grad);
171+
// gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
172+
gemm_op()(this->ctx,
173+
'C',
174+
'N',
175+
this->n_band, //m
176+
this->n_band, //n
177+
this->n_dim, //k
178+
this->one, //1.0
179+
psi_in.data<T>(),
180+
this->n_basis, //lda
181+
grad_out.data<T>(),
182+
this->n_basis, //ldb
183+
this->zero, //0.0
184+
hsub_in.data<T>(),
185+
this->n_band); //ldc
149186

150187
// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
151188
option = ct::EinsumOption(
152189
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
153-
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
190+
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191+
192+
// grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193+
gemm_op()(this->ctx,
194+
'N',
195+
'N',
196+
this->n_dim, //m
197+
this->n_band, //n
198+
this->n_band, //k
199+
this->neg_one, //-1.0
200+
psi_in.data<T>(),
201+
this->n_basis, //lda
202+
hsub_in.data<T>(),
203+
this->n_band, //ldb
204+
this->one, //1.0
205+
grad_out.data<T>(),
206+
this->n_basis); //ldc
154207

155208
return;
156209
}
@@ -165,6 +218,24 @@ void DiagoBPCG<T, Device>::rotate_wf(
165218
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in);
166219
workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
167220

221+
// this->rotate_wf(hsub_out, psi_out, workspace_in);
222+
// this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
223+
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
224+
// gemm_op()(this->ctx,
225+
// 'N',
226+
// 'N',
227+
// this->n_basis, //m
228+
// this->n_band, //n
229+
// this->n_band, //k
230+
// this->one, //1.0
231+
// psi_out.data<T>(),
232+
// this->n_basis, //lda
233+
// hsub_in.data<T>(),
234+
// this->n_band, //ldb
235+
// this->zero, //0.0
236+
// workspace_in.data<T>(),
237+
// this->n_basis); //ldc
238+
168239
syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);
169240

170241
return;
@@ -192,7 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
192263
// it controls the ops to use the corresponding device to calculate results
193264
ct::EinsumOption option(
194265
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
195-
hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
266+
// hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267+
268+
// gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
269+
gemm_op()(this->ctx,
270+
'C',
271+
'N',
272+
this->n_band, //m
273+
this->n_band, //n
274+
this->n_dim, //k
275+
this->one, //1.0
276+
hpsi_in.data<T>(),
277+
this->n_basis, //lda
278+
psi_in.data<T>(),
279+
this->n_basis, //ldb
280+
this->zero, //0.0
281+
hsub_out.data<T>(),
282+
this->n_band); //ldc
196283

197284
ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
198285

source/module_hsolver/diago_bpcg.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ class DiagoBPCG
5252
*
5353
* @param nband The number of bands.
5454
* @param nbasis The number of basis functions. Leading dimension of psi.
55+
* @param ndim The number of valid dimension of psi.
5556
*/
56-
void init_iter(const int nband, const int nbasis);
57+
void init_iter(const int nband, const int nbasis, const int ndim);
5758

5859
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
5960

@@ -77,6 +78,8 @@ class DiagoBPCG
7778
int n_band = 0;
7879
/// the number of cols of the input psi
7980
int n_basis = 0;
81+
/// valid dimension of psi
82+
int n_dim = 0;
8083
/// max iter steps for all-band cg loop
8184
int nline = 4;
8285

@@ -107,6 +110,13 @@ class DiagoBPCG
107110
/// work for some calculations within this class, including rotate_wf call
108111
ct::Tensor work = {};
109112

113+
// These are for hsolver gemm_op use
114+
/// ctx is nothing but the devices used in gemm_op (Device * ctx = nullptr;),
115+
Device * ctx = {};
116+
// Pointer to objects of 1 and 0 for gemm
117+
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
118+
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);
119+
110120
/**
111121
* @brief Update the precondition array.
112122
*
@@ -332,6 +342,7 @@ class DiagoBPCG
332342

333343
using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
334344
using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
345+
using gemm_op = hsolver::gemm_op<T, Device>;
335346

336347
};
337348

source/module_hsolver/hsolver_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
483483
{
484484
const int nband = psi.get_nbands();
485485
const int nbasis = psi.get_nbasis();
486+
const int ndim = psi.get_current_ngk();
486487
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
487488
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
488489
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
@@ -499,7 +500,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
499500
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
500501
};
501502
DiagoBPCG<T, Device> bpcg(pre_condition.data());
502-
bpcg.init_iter(nband, nbasis);
503+
bpcg.init_iter(nband, nbasis, ndim);
503504
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
504505
}
505506
else if (this->method == "dav_subspace")

source/module_hsolver/test/diago_bpcg_test.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ class DiagoBPCGPrepare
153153
zero_,
154154
hpsi_out, ld_psi);
155155
};
156-
bpcg.init_iter(nband, npw);
156+
const int ndim = psi_local.get_current_ngk();
157+
bpcg.init_iter(nband, npw, ndim);
157158
std::vector<double> ethr_band(nband, 1e-5);
158159
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
159160
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);

tests/integrate/102_PW_BPCG/result.ref

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
etotref -4869.74705201
22
etotperatomref -2434.87352600
33
totalforceref 5.19483000
4-
totalstressref 37241.44843500
4+
totalstressref 37241.45334600
55
pointgroupref C_1
66
spacegroupref C_1
77
nksibzref 8

0 commit comments

Comments
 (0)