Skip to content

Commit e667622

Browse files
committed
Refactor: Remove bpcg dependency on Psi and Hamilt
1 parent 7198ec1 commit e667622

File tree

3 files changed

+69
-29
lines changed

3 files changed

+69
-29
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
2727
template<typename T, typename Device>
2828
DiagoBPCG<T, Device>::~DiagoBPCG() {
2929
// Note, we do not need to free the h_prec and psi pointer as they are refs to the outside data
30-
delete this->grad_wrapper;
30+
// delete this->grad_wrapper;
3131
}
3232

3333
template<typename T, typename Device>
34-
void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
34+
void DiagoBPCG<T, Device>::init_iter(/*const T *psi_in,*/ const int nband, const int nbasis) {
3535
// Specify the problem size n_basis, n_band, while lda is n_basis
36-
this->n_band = psi_in.get_nbands();
37-
this->n_basis = psi_in.get_nbasis();
36+
// this->n_band = psi_in.get_nbands();
37+
// this->n_basis = psi_in.get_nbasis();
38+
this->n_band = nband;
39+
this->n_basis = nbasis;
40+
3841

3942
// All column major tensors
4043

@@ -52,8 +55,9 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
5255
this->prec = std::move(ct::Tensor(r_type, device_type, {this->n_basis}));
5356

5457
//TODO: Remove class Psi, using ct::Tensor instead!
55-
this->grad_wrapper = new psi::Psi<T, Device>(1, this->n_band, this->n_basis, psi_in.get_ngk_pointer());
56-
this->grad = std::move(ct::TensorMap(grad_wrapper->get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
58+
// this->grad_wrapper = new psi::Psi<T, Device>(1, this->n_band, this->n_basis, psi_in.get_ngk_pointer());
59+
// this->grad = std::move(ct::TensorMap(grad_wrapper->get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
60+
this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis}));
5761
}
5862

5963
template<typename T, typename Device>
@@ -174,16 +178,19 @@ void DiagoBPCG<T, Device>::rotate_wf(
174178

175179
template<typename T, typename Device>
176180
void DiagoBPCG<T, Device>::calc_hpsi_with_block(
177-
hamilt::Hamilt<T, Device>* hamilt_in,
178-
const psi::Psi<T, Device>& psi_in,
181+
// hamilt::Hamilt<T, Device>* hamilt_in,
182+
const HPsiFunc& hpsi_func,
183+
// const psi::Psi<T, Device>& psi_in,
184+
T *psi_in,
179185
ct::Tensor& hpsi_out)
180186
{
181187
// calculate all-band hpsi
182-
psi::Range all_bands_range(1, psi_in.get_current_k(), 0, psi_in.get_nbands() - 1);
183-
hpsi_info info(&psi_in, all_bands_range, hpsi_out.data<T>());
184-
hamilt_in->ops->hPsi(info);
188+
// psi::Range all_bands_range(1, psi_in.get_current_k(), 0, psi_in.get_nbands() - 1);
189+
// hpsi_info info(&psi_in, all_bands_range, hpsi_out.data<T>());
190+
// hamilt_in->ops->hPsi(info);
191+
hpsi_func(psi_in, hpsi_out.data<T>(), this->n_basis, this->n_band);
185192

186-
return;
193+
// return;
187194
}
188195

189196
template<typename T, typename Device>
@@ -207,16 +214,18 @@ void DiagoBPCG<T, Device>::diag_hsub(
207214

208215
template<typename T, typename Device>
209216
void DiagoBPCG<T, Device>::calc_hsub_with_block(
210-
hamilt::Hamilt<T, Device> *hamilt_in,
211-
const psi::Psi<T, Device> &psi_in,
217+
// hamilt::Hamilt<T, Device>* hamilt_in,
218+
const HPsiFunc& hpsi_func,
219+
// const psi::Psi<T, Device>& psi_in,
220+
T *psi_in,
212221
ct::Tensor& psi_out,
213222
ct::Tensor& hpsi_out,
214223
ct::Tensor& hsub_out,
215224
ct::Tensor& workspace_in,
216225
ct::Tensor& eigenvalue_out)
217226
{
218227
// Apply the H operator to psi and obtain the hpsi matrix.
219-
this->calc_hpsi_with_block(hamilt_in, psi_in, hpsi_out);
228+
this->calc_hpsi_with_block(hpsi_func, psi_in, hpsi_out);
220229

221230
// Diagonalization of the subspace matrix.
222231
this->diag_hsub(psi_out,hpsi_out, hsub_out, eigenvalue_out);
@@ -250,19 +259,21 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
250259

251260
template<typename T, typename Device>
252261
void DiagoBPCG<T, Device>::diag(
253-
hamilt::Hamilt<T, Device>* hamilt_in,
254-
psi::Psi<T, Device>& psi_in,
262+
// hamilt::Hamilt<T, Device>* hamilt_in,
263+
const HPsiFunc& hpsi_func,
264+
// psi::Psi<T, Device>& psi_in,
265+
T *psi_in,
255266
Real* eigenvalue_in)
256267
{
257268
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
258269
// Get the pointer of the input psi
259-
this->psi = std::move(ct::TensorMap(psi_in.get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
270+
this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis}));
260271

261272
// Update the precondition array
262273
this->calc_prec();
263274

264275
// Improving the initial guess of the wave function psi through a subspace diagonalization.
265-
this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
276+
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
266277

267278
setmem_complex_op()(this->grad_old.template data<T>(), 0, this->n_basis * this->n_band);
268279

@@ -293,7 +304,7 @@ void DiagoBPCG<T, Device>::diag(
293304
syncmem_complex_op()(this->grad_old.template data<T>(), this->grad.template data<T>(), n_basis * n_band);
294305

295306
// Calculate H|grad> matrix
296-
this->calc_hpsi_with_block(hamilt_in, this->grad_wrapper[0], this->hgrad);
307+
this->calc_hpsi_with_block(hpsi_func, this->grad.data<T>(), /*this->grad_wrapper[0],*/ this->hgrad);
297308

298309
// optimize psi as well as the hpsi
299310
// 1. normalize grad
@@ -305,7 +316,7 @@ void DiagoBPCG<T, Device>::diag(
305316
this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
306317

307318
if (current_scf_iter == 1 && ntry % this->nline == 0) {
308-
this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
319+
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
309320
}
310321
} while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr));
311322

source/module_hsolver/diago_bpcg.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ class DiagoBPCG
5252
*
5353
* @param psi_in The input wavefunction psi.
5454
*/
55-
void init_iter(const psi::Psi<T, Device> &psi_in);
55+
void init_iter(/*const T *psi_in,*/ const int nband, const int nbasis);
56+
57+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
5658

5759
/**
5860
* @brief Diagonalize the Hamiltonian using the BPCG method.
@@ -63,7 +65,10 @@ class DiagoBPCG
6365
* @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major].
6466
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
6567
*/
66-
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in);
68+
void diag(// hamilt::Hamilt<T, Device> *phm_in,
69+
const HPsiFunc& hpsi_func,
70+
// psi::Psi<T, Device>& psi_in,
71+
T *psi_in, Real *eigenvalue_in);
6772

6873

6974
private:
@@ -139,8 +144,10 @@ class DiagoBPCG
139144
* @param hpsi_out Pointer to the array where the resulting hpsi matrix will be stored.
140145
*/
141146
void calc_hpsi_with_block(
142-
hamilt::Hamilt<T, Device>* hamilt_in,
143-
const psi::Psi<T, Device>& psi_in,
147+
// hamilt::Hamilt<T, Device>* hamilt_in,
148+
const HPsiFunc& hpsi_func,
149+
// const psi::Psi<T, Device>& psi_in,
150+
T *psi_in,
144151
ct::Tensor& hpsi_out);
145152

146153
/**
@@ -228,8 +235,10 @@ class DiagoBPCG
228235
* @param eigenvalue_out Computed eigen.
229236
*/
230237
void calc_hsub_with_block(
231-
hamilt::Hamilt<T, Device>* hamilt_in,
232-
const psi::Psi<T, Device>& psi_in,
238+
// hamilt::Hamilt<T, Device>* hamilt_in,
239+
const HPsiFunc& hpsi_func,
240+
// const psi::Psi<T, Device>& psi_in,
241+
T *psi_in,
233242
ct::Tensor& psi_out, ct::Tensor& hpsi_out,
234243
ct::Tensor& hsub_out, ct::Tensor& workspace_in,
235244
ct::Tensor& eigenvalue_out);

source/module_hsolver/hsolver_pw.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,9 +467,29 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
467467
}
468468
else if (this->method == "bpcg")
469469
{
470+
const int nband = psi.get_nbands();
471+
const int nbasis = psi.get_nbasis();
472+
auto ngk_pointer = psi.get_ngk_pointer();
473+
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
474+
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
475+
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
476+
477+
// Convert "pointer data stucture" to a psi::Psi object
478+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
479+
480+
psi::Range bands_range(true, 0, 0, nvec - 1);
481+
482+
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
483+
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
484+
hm->ops->hPsi(info);
485+
486+
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
487+
};
470488
DiagoBPCG<T, Device> bpcg(pre_condition.data());
471-
bpcg.init_iter(psi);
472-
bpcg.diag(hm, psi, eigenvalue);
489+
// bpcg.init_iter(psi);
490+
bpcg.init_iter(nband, nbasis);
491+
// bpcg.diag(hm, psi, eigenvalue);
492+
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue);
473493
}
474494
else if (this->method == "dav_subspace")
475495
{

0 commit comments

Comments
 (0)