Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions source/module_hsolver/diago_bpcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
template<typename T, typename Device>
DiagoBPCG<T, Device>::~DiagoBPCG() {
// Note, we do not need to free the h_prec and psi pointer as they are refs to the outside data
delete this->grad_wrapper;
}

template<typename T, typename Device>
void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
// Specify the problem size n_basis, n_band, while lda is n_basis
this->n_band = psi_in.get_nbands();
this->n_basis = psi_in.get_nbasis();
this->n_band = nband;
this->n_basis = nbasis;


// All column major tensors

Expand All @@ -51,9 +51,7 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {

this->prec = std::move(ct::Tensor(r_type, device_type, {this->n_basis}));

//TODO: Remove class Psi, using ct::Tensor instead!
this->grad_wrapper = new psi::Psi<T, Device>(1, this->n_band, this->n_basis, psi_in.get_ngk_pointer());
this->grad = std::move(ct::TensorMap(grad_wrapper->get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis}));
}

template<typename T, typename Device>
Expand Down Expand Up @@ -174,16 +172,12 @@ void DiagoBPCG<T, Device>::rotate_wf(

template<typename T, typename Device>
void DiagoBPCG<T, Device>::calc_hpsi_with_block(
hamilt::Hamilt<T, Device>* hamilt_in,
const psi::Psi<T, Device>& psi_in,
const HPsiFunc& hpsi_func,
T *psi_in,
ct::Tensor& hpsi_out)
{
// calculate all-band hpsi
psi::Range all_bands_range(1, psi_in.get_current_k(), 0, psi_in.get_nbands() - 1);
hpsi_info info(&psi_in, all_bands_range, hpsi_out.data<T>());
hamilt_in->ops->hPsi(info);

return;
hpsi_func(psi_in, hpsi_out.data<T>(), this->n_basis, this->n_band);
}

template<typename T, typename Device>
Expand All @@ -207,16 +201,16 @@ void DiagoBPCG<T, Device>::diag_hsub(

template<typename T, typename Device>
void DiagoBPCG<T, Device>::calc_hsub_with_block(
hamilt::Hamilt<T, Device> *hamilt_in,
const psi::Psi<T, Device> &psi_in,
const HPsiFunc& hpsi_func,
T *psi_in,
ct::Tensor& psi_out,
ct::Tensor& hpsi_out,
ct::Tensor& hsub_out,
ct::Tensor& workspace_in,
ct::Tensor& eigenvalue_out)
{
// Apply the H operator to psi and obtain the hpsi matrix.
this->calc_hpsi_with_block(hamilt_in, psi_in, hpsi_out);
this->calc_hpsi_with_block(hpsi_func, psi_in, hpsi_out);

// Diagonalization of the subspace matrix.
this->diag_hsub(psi_out,hpsi_out, hsub_out, eigenvalue_out);
Expand Down Expand Up @@ -250,19 +244,19 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(

template<typename T, typename Device>
void DiagoBPCG<T, Device>::diag(
hamilt::Hamilt<T, Device>* hamilt_in,
psi::Psi<T, Device>& psi_in,
const HPsiFunc& hpsi_func,
T *psi_in,
Real* eigenvalue_in)
{
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
// Get the pointer of the input psi
this->psi = std::move(ct::TensorMap(psi_in.get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis}));

// Update the precondition array
this->calc_prec();

// Improving the initial guess of the wave function psi through a subspace diagonalization.
this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);

setmem_complex_op()(this->grad_old.template data<T>(), 0, this->n_basis * this->n_band);

Expand Down Expand Up @@ -293,7 +287,7 @@ void DiagoBPCG<T, Device>::diag(
syncmem_complex_op()(this->grad_old.template data<T>(), this->grad.template data<T>(), n_basis * n_band);

// Calculate H|grad> matrix
this->calc_hpsi_with_block(hamilt_in, this->grad_wrapper[0], this->hgrad);
this->calc_hpsi_with_block(hpsi_func, this->grad.template data<T>(), /*this->grad_wrapper[0],*/ this->hgrad);

// optimize psi as well as the hpsi
// 1. normalize grad
Expand All @@ -305,7 +299,7 @@ void DiagoBPCG<T, Device>::diag(
this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);

if (current_scf_iter == 1 && ntry % this->nline == 0) {
this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
}
} while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr));

Expand Down
32 changes: 17 additions & 15 deletions source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,24 @@ class DiagoBPCG
* This function allocates all the related variables, such as hpsi, hsub, before the diag call.
* It is called by the HsolverPW::initDiagh() function.
*
* @param psi_in The input wavefunction psi.
* @param nband The number of bands.
* @param nbasis The number of basis functions. Leading dimension of psi.
*/
void init_iter(const psi::Psi<T, Device> &psi_in);
void init_iter(const int nband, const int nbasis);

using HPsiFunc = std::function<void(T*, T*, const int, const int)>;

/**
* @brief Diagonalize the Hamiltonian using the BPCG method.
*
* This function is called by the HsolverPW::solve() function.
*
* @param phm_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator.
* @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major].
* @param hpsi_func A function computing the product of the Hamiltonian matrix H
* and a wavefunction blockvector X.
* @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major].
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
*/
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in);
void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in);


private:
Expand Down Expand Up @@ -103,7 +107,6 @@ class DiagoBPCG
/// work for some calculations within this class, including rotate_wf call
ct::Tensor work = {};

psi::Psi<T, Device>* grad_wrapper;
/**
* @brief Update the precondition array.
*
Expand Down Expand Up @@ -134,13 +137,14 @@ class DiagoBPCG
* psi_in[dim: n_basis x n_band, column major, lda = n_basis_max],
* hpsi_out[dim: n_basis x n_band, column major, lda = n_basis_max].
*
* @param hamilt_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator.
* @param hpsi_func A function computing the product of the Hamiltonian matrix H
* and a wavefunction blockvector X.
* @param psi_in The input wavefunction psi.
* @param hpsi_out Pointer to the array where the resulting hpsi matrix will be stored.
*/
void calc_hpsi_with_block(
hamilt::Hamilt<T, Device>* hamilt_in,
const psi::Psi<T, Device>& psi_in,
const HPsiFunc& hpsi_func,
T *psi_in,
ct::Tensor& hpsi_out);

/**
Expand Down Expand Up @@ -220,16 +224,16 @@ class DiagoBPCG
* hsub_out[dim: n_band x n_band, column major, lda = n_band],
* eigenvalue_out[dim: n_basis_max, column major].
*
* @param hamilt_in Pointer to the Hamiltonian object.
* @param psi_in Input wavefunction.
* @param hpsi_func A function computing the product of matrix H and wavefunction blockvector X.
* @param psi_in Input wavefunction pointer.
* @param psi_out Output wavefunction.
* @param hpsi_out Product of psi_out and Hamiltonian.
* @param hsub_out Subspace matrix output.
* @param eigenvalue_out Computed eigen.
*/
void calc_hsub_with_block(
hamilt::Hamilt<T, Device>* hamilt_in,
const psi::Psi<T, Device>& psi_in,
const HPsiFunc& hpsi_func,
T *psi_in,
ct::Tensor& psi_out, ct::Tensor& hpsi_out,
ct::Tensor& hsub_out, ct::Tensor& workspace_in,
ct::Tensor& eigenvalue_out);
Expand Down Expand Up @@ -314,8 +318,6 @@ class DiagoBPCG
*/
bool test_error(const ct::Tensor& err_in, Real thr_in);

using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;

using ct_Device = typename ct::PsiToContainer<Device>::type;
using setmem_var_op = ct::kernels::set_memory<Real, ct_Device>;
using resmem_var_op = ct::kernels::resize_memory<Real, ct_Device>;
Expand Down
22 changes: 20 additions & 2 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,27 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
}
else if (this->method == "bpcg")
{
const int nband = psi.get_nbands();
const int nbasis = psi.get_nbasis();
auto ngk_pointer = psi.get_ngk_pointer();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);

psi::Range bands_range(true, 0, 0, nvec - 1);

using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
hm->ops->hPsi(info);

ModuleBase::timer::tick("DavSubspace", "hpsi_func");
};
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(psi);
bpcg.diag(hm, psi, eigenvalue);
bpcg.init_iter(nband, nbasis);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue);
}
else if (this->method == "dav_subspace")
{
Expand Down
53 changes: 26 additions & 27 deletions source/module_hsolver/test/diago_bpcg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,32 @@ class DiagoBPCGPrepare
psi_local.fix_k(0);
double start, end;
start = MPI_Wtime();
bpcg.init_iter(psi_local);
bpcg.diag(ha,psi_local,en);
bpcg.diag(ha,psi_local,en);
bpcg.diag(ha,psi_local,en);
using T = std::complex<double>;
const int dim = DIAGOTEST::npw;
const std::vector<T> &h_mat = DIAGOTEST::hmatrix_local;
auto hpsi_func = [h_mat, dim](T *psi_in, T *hpsi_out,
const int ld_psi, const int nvec) {
auto one = std::make_unique<T>(1.0);
auto zero = std::make_unique<T>(0.0);
const T *one_ = one.get();
const T *zero_ = zero.get();

base_device::DEVICE_CPU *ctx = {};
// hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec)
hsolver::gemm_op<T, base_device::DEVICE_CPU>()(
ctx, 'N', 'N',
dim, nvec, dim,
one_,
h_mat.data(), dim,
psi_in, ld_psi,
zero_,
hpsi_out, ld_psi);
};
bpcg.init_iter(nband, npw);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
end = MPI_Wtime();
//if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
delete [] DIAGOTEST::npw_local;
Expand Down Expand Up @@ -219,29 +241,6 @@ TEST(DiagoBPCGTest, Hamilt)
}
}*/

// bpcg for a 2x2 matrix
#ifdef __MPI
#else
TEST(DiagoBPCGTest, TwoByTwo)
{
int dim = 2;
int nband = 2;
ModuleBase::ComplexMatrix hm(2, 2);
hm(0, 0) = std::complex<double>{4.0, 0.0};
hm(0, 1) = std::complex<double>{1.0, 0.0};
hm(1, 0) = std::complex<double>{1.0, 0.0};
hm(1, 1) = std::complex<double>{3.0, 0.0};
// nband, npw, sub, sparsity, reorder, eps, maxiter, threshold
DiagoBPCGPrepare dcp(nband, dim, 0, true, 1e-4, 50, 1e-10);
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = dcp.maxiter;
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR = dcp.eps;
HPsi<std::complex<double>> hpsi;
hpsi.create(nband, dim);
DIAGOTEST::hmatrix = hm;
DIAGOTEST::npw = dim;
dcp.CompareEigen(hpsi.precond());
}
#endif

TEST(DiagoBPCGTest, readH)
{
Expand Down
Loading