Skip to content

Commit 5ab8ab8

Browse files
authored
Refactor: Remove bpcg dependency on Psi and Hamilt (#5643)
* Refactor: Remove bpcg dependency on Psi and Hamilt * Test: change bpcg tests to fit new interface * Test: make bpcg restart the same time as before in test * Tests: add the template disambiguator for dependent names for bpcg tests * Refactor: clean up bpcg * Docs: new BPCG interface * clean up bpcg code * clean useless code * clean useless code * clean useless code
1 parent bfe6364 commit 5ab8ab8

File tree

4 files changed

+80
-67
lines changed

4 files changed

+80
-67
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
2727
template<typename T, typename Device>
2828
DiagoBPCG<T, Device>::~DiagoBPCG() {
2929
// Note, we do not need to free the h_prec and psi pointer as they are refs to the outside data
30-
delete this->grad_wrapper;
3130
}
3231

3332
template<typename T, typename Device>
34-
void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
33+
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
3534
// Specify the problem size n_basis, n_band, while lda is n_basis
36-
this->n_band = psi_in.get_nbands();
37-
this->n_basis = psi_in.get_nbasis();
35+
this->n_band = nband;
36+
this->n_basis = nbasis;
37+
3838

3939
// All column major tensors
4040

@@ -51,9 +51,7 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
5151

5252
this->prec = std::move(ct::Tensor(r_type, device_type, {this->n_basis}));
5353

54-
//TODO: Remove class Psi, using ct::Tensor instead!
55-
this->grad_wrapper = new psi::Psi<T, Device>(1, this->n_band, this->n_basis, psi_in.get_ngk_pointer());
56-
this->grad = std::move(ct::TensorMap(grad_wrapper->get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
54+
this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis}));
5755
}
5856

5957
template<typename T, typename Device>
@@ -174,16 +172,12 @@ void DiagoBPCG<T, Device>::rotate_wf(
174172

175173
template<typename T, typename Device>
176174
void DiagoBPCG<T, Device>::calc_hpsi_with_block(
177-
hamilt::Hamilt<T, Device>* hamilt_in,
178-
const psi::Psi<T, Device>& psi_in,
175+
const HPsiFunc& hpsi_func,
176+
T *psi_in,
179177
ct::Tensor& hpsi_out)
180178
{
181179
// calculate all-band hpsi
182-
psi::Range all_bands_range(1, psi_in.get_current_k(), 0, psi_in.get_nbands() - 1);
183-
hpsi_info info(&psi_in, all_bands_range, hpsi_out.data<T>());
184-
hamilt_in->ops->hPsi(info);
185-
186-
return;
180+
hpsi_func(psi_in, hpsi_out.data<T>(), this->n_basis, this->n_band);
187181
}
188182

189183
template<typename T, typename Device>
@@ -207,16 +201,16 @@ void DiagoBPCG<T, Device>::diag_hsub(
207201

208202
template<typename T, typename Device>
209203
void DiagoBPCG<T, Device>::calc_hsub_with_block(
210-
hamilt::Hamilt<T, Device> *hamilt_in,
211-
const psi::Psi<T, Device> &psi_in,
204+
const HPsiFunc& hpsi_func,
205+
T *psi_in,
212206
ct::Tensor& psi_out,
213207
ct::Tensor& hpsi_out,
214208
ct::Tensor& hsub_out,
215209
ct::Tensor& workspace_in,
216210
ct::Tensor& eigenvalue_out)
217211
{
218212
// Apply the H operator to psi and obtain the hpsi matrix.
219-
this->calc_hpsi_with_block(hamilt_in, psi_in, hpsi_out);
213+
this->calc_hpsi_with_block(hpsi_func, psi_in, hpsi_out);
220214

221215
// Diagonalization of the subspace matrix.
222216
this->diag_hsub(psi_out,hpsi_out, hsub_out, eigenvalue_out);
@@ -250,19 +244,19 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
250244

251245
template<typename T, typename Device>
252246
void DiagoBPCG<T, Device>::diag(
253-
hamilt::Hamilt<T, Device>* hamilt_in,
254-
psi::Psi<T, Device>& psi_in,
247+
const HPsiFunc& hpsi_func,
248+
T *psi_in,
255249
Real* eigenvalue_in)
256250
{
257251
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
258252
// Get the pointer of the input psi
259-
this->psi = std::move(ct::TensorMap(psi_in.get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
253+
this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis}));
260254

261255
// Update the precondition array
262256
this->calc_prec();
263257

264258
// Improving the initial guess of the wave function psi through a subspace diagonalization.
265-
this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
259+
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
266260

267261
setmem_complex_op()(this->grad_old.template data<T>(), 0, this->n_basis * this->n_band);
268262

@@ -293,7 +287,7 @@ void DiagoBPCG<T, Device>::diag(
293287
syncmem_complex_op()(this->grad_old.template data<T>(), this->grad.template data<T>(), n_basis * n_band);
294288

295289
// Calculate H|grad> matrix
296-
this->calc_hpsi_with_block(hamilt_in, this->grad_wrapper[0], this->hgrad);
290+
this->calc_hpsi_with_block(hpsi_func, this->grad.template data<T>(), /*this->grad_wrapper[0],*/ this->hgrad);
297291

298292
// optimize psi as well as the hpsi
299293
// 1. normalize grad
@@ -305,7 +299,7 @@ void DiagoBPCG<T, Device>::diag(
305299
this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
306300

307301
if (current_scf_iter == 1 && ntry % this->nline == 0) {
308-
this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
302+
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
309303
}
310304
} while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr));
311305

source/module_hsolver/diago_bpcg.h

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,24 @@ class DiagoBPCG
5050
* This function allocates all the related variables, such as hpsi, hsub, before the diag call.
5151
* It is called by the HsolverPW::initDiagh() function.
5252
*
53-
* @param psi_in The input wavefunction psi.
53+
* @param nband The number of bands.
54+
* @param nbasis The number of basis functions. Leading dimension of psi.
5455
*/
55-
void init_iter(const psi::Psi<T, Device> &psi_in);
56+
void init_iter(const int nband, const int nbasis);
57+
58+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
5659

5760
/**
5861
* @brief Diagonalize the Hamiltonian using the BPCG method.
5962
*
6063
* This function is called by the HsolverPW::solve() function.
6164
*
62-
* @param phm_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator.
63-
* @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major].
65+
* @param hpsi_func A function computing the product of the Hamiltonian matrix H
66+
* and a wavefunction blockvector X.
67+
* @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major].
6468
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
6569
*/
66-
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in);
70+
void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in);
6771

6872

6973
private:
@@ -103,7 +107,6 @@ class DiagoBPCG
103107
/// work for some calculations within this class, including rotate_wf call
104108
ct::Tensor work = {};
105109

106-
psi::Psi<T, Device>* grad_wrapper;
107110
/**
108111
* @brief Update the precondition array.
109112
*
@@ -134,13 +137,14 @@ class DiagoBPCG
134137
* psi_in[dim: n_basis x n_band, column major, lda = n_basis_max],
135138
* hpsi_out[dim: n_basis x n_band, column major, lda = n_basis_max].
136139
*
137-
* @param hamilt_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator.
140+
* @param hpsi_func A function computing the product of the Hamiltonian matrix H
141+
* and a wavefunction blockvector X.
138142
* @param psi_in The input wavefunction psi.
139143
* @param hpsi_out Pointer to the array where the resulting hpsi matrix will be stored.
140144
*/
141145
void calc_hpsi_with_block(
142-
hamilt::Hamilt<T, Device>* hamilt_in,
143-
const psi::Psi<T, Device>& psi_in,
146+
const HPsiFunc& hpsi_func,
147+
T *psi_in,
144148
ct::Tensor& hpsi_out);
145149

146150
/**
@@ -220,16 +224,16 @@ class DiagoBPCG
220224
* hsub_out[dim: n_band x n_band, column major, lda = n_band],
221225
* eigenvalue_out[dim: n_basis_max, column major].
222226
*
223-
* @param hamilt_in Pointer to the Hamiltonian object.
224-
* @param psi_in Input wavefunction.
227+
* @param hpsi_func A function computing the product of matrix H and wavefunction blockvector X.
228+
* @param psi_in Input wavefunction pointer.
225229
* @param psi_out Output wavefunction.
226230
* @param hpsi_out Product of psi_out and Hamiltonian.
227231
* @param hsub_out Subspace matrix output.
228232
* @param eigenvalue_out Computed eigen.
229233
*/
230234
void calc_hsub_with_block(
231-
hamilt::Hamilt<T, Device>* hamilt_in,
232-
const psi::Psi<T, Device>& psi_in,
235+
const HPsiFunc& hpsi_func,
236+
T *psi_in,
233237
ct::Tensor& psi_out, ct::Tensor& hpsi_out,
234238
ct::Tensor& hsub_out, ct::Tensor& workspace_in,
235239
ct::Tensor& eigenvalue_out);
@@ -314,8 +318,6 @@ class DiagoBPCG
314318
*/
315319
bool test_error(const ct::Tensor& err_in, Real thr_in);
316320

317-
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
318-
319321
using ct_Device = typename ct::PsiToContainer<Device>::type;
320322
using setmem_var_op = ct::kernels::set_memory<Real, ct_Device>;
321323
using resmem_var_op = ct::kernels::resize_memory<Real, ct_Device>;

source/module_hsolver/hsolver_pw.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,9 +467,27 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
467467
}
468468
else if (this->method == "bpcg")
469469
{
470+
const int nband = psi.get_nbands();
471+
const int nbasis = psi.get_nbasis();
472+
auto ngk_pointer = psi.get_ngk_pointer();
473+
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
474+
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
475+
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
476+
477+
// Convert "pointer data stucture" to a psi::Psi object
478+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
479+
480+
psi::Range bands_range(true, 0, 0, nvec - 1);
481+
482+
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
483+
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
484+
hm->ops->hPsi(info);
485+
486+
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
487+
};
470488
DiagoBPCG<T, Device> bpcg(pre_condition.data());
471-
bpcg.init_iter(psi);
472-
bpcg.diag(hm, psi, eigenvalue);
489+
bpcg.init_iter(nband, nbasis);
490+
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue);
473491
}
474492
else if (this->method == "dav_subspace")
475493
{

source/module_hsolver/test/diago_bpcg_test.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,32 @@ class DiagoBPCGPrepare
130130
psi_local.fix_k(0);
131131
double start, end;
132132
start = MPI_Wtime();
133-
bpcg.init_iter(psi_local);
134-
bpcg.diag(ha,psi_local,en);
135-
bpcg.diag(ha,psi_local,en);
136-
bpcg.diag(ha,psi_local,en);
133+
using T = std::complex<double>;
134+
const int dim = DIAGOTEST::npw;
135+
const std::vector<T> &h_mat = DIAGOTEST::hmatrix_local;
136+
auto hpsi_func = [h_mat, dim](T *psi_in, T *hpsi_out,
137+
const int ld_psi, const int nvec) {
138+
auto one = std::make_unique<T>(1.0);
139+
auto zero = std::make_unique<T>(0.0);
140+
const T *one_ = one.get();
141+
const T *zero_ = zero.get();
142+
143+
base_device::DEVICE_CPU *ctx = {};
144+
// hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec)
145+
hsolver::gemm_op<T, base_device::DEVICE_CPU>()(
146+
ctx, 'N', 'N',
147+
dim, nvec, dim,
148+
one_,
149+
h_mat.data(), dim,
150+
psi_in, ld_psi,
151+
zero_,
152+
hpsi_out, ld_psi);
153+
};
154+
bpcg.init_iter(nband, npw);
155+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
156+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
157+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
158+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
137159
end = MPI_Wtime();
138160
//if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
139161
delete [] DIAGOTEST::npw_local;
@@ -219,29 +241,6 @@ TEST(DiagoBPCGTest, Hamilt)
219241
}
220242
}*/
221243

222-
// bpcg for a 2x2 matrix
223-
#ifdef __MPI
224-
#else
225-
TEST(DiagoBPCGTest, TwoByTwo)
226-
{
227-
int dim = 2;
228-
int nband = 2;
229-
ModuleBase::ComplexMatrix hm(2, 2);
230-
hm(0, 0) = std::complex<double>{4.0, 0.0};
231-
hm(0, 1) = std::complex<double>{1.0, 0.0};
232-
hm(1, 0) = std::complex<double>{1.0, 0.0};
233-
hm(1, 1) = std::complex<double>{3.0, 0.0};
234-
// nband, npw, sub, sparsity, reorder, eps, maxiter, threshold
235-
DiagoBPCGPrepare dcp(nband, dim, 0, true, 1e-4, 50, 1e-10);
236-
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = dcp.maxiter;
237-
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR = dcp.eps;
238-
HPsi<std::complex<double>> hpsi;
239-
hpsi.create(nband, dim);
240-
DIAGOTEST::hmatrix = hm;
241-
DIAGOTEST::npw = dim;
242-
dcp.CompareEigen(hpsi.precond());
243-
}
244-
#endif
245244

246245
TEST(DiagoBPCGTest, readH)
247246
{

0 commit comments

Comments
 (0)