diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 63d0956464..aaa21877bc 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -27,14 +27,14 @@ DiagoBPCG::DiagoBPCG(const Real* precondition_in) template DiagoBPCG::~DiagoBPCG() { // Note, we do not need to free the h_prec and psi pointer as they are refs to the outside data - delete this->grad_wrapper; } template -void DiagoBPCG::init_iter(const psi::Psi &psi_in) { +void DiagoBPCG::init_iter(const int nband, const int nbasis) { // Specify the problem size n_basis, n_band, while lda is n_basis - this->n_band = psi_in.get_nbands(); - this->n_basis = psi_in.get_nbasis(); + this->n_band = nband; + this->n_basis = nbasis; + // All column major tensors @@ -51,9 +51,7 @@ void DiagoBPCG::init_iter(const psi::Psi &psi_in) { this->prec = std::move(ct::Tensor(r_type, device_type, {this->n_basis})); - //TODO: Remove class Psi, using ct::Tensor instead! - this->grad_wrapper = new psi::Psi(1, this->n_band, this->n_basis, psi_in.get_ngk_pointer()); - this->grad = std::move(ct::TensorMap(grad_wrapper->get_pointer(), t_type, device_type, {this->n_band, this->n_basis})); + this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); } template @@ -174,16 +172,12 @@ void DiagoBPCG::rotate_wf( template void DiagoBPCG::calc_hpsi_with_block( - hamilt::Hamilt* hamilt_in, - const psi::Psi& psi_in, + const HPsiFunc& hpsi_func, + T *psi_in, ct::Tensor& hpsi_out) { // calculate all-band hpsi - psi::Range all_bands_range(1, psi_in.get_current_k(), 0, psi_in.get_nbands() - 1); - hpsi_info info(&psi_in, all_bands_range, hpsi_out.data()); - hamilt_in->ops->hPsi(info); - - return; + hpsi_func(psi_in, hpsi_out.data(), this->n_basis, this->n_band); } template @@ -207,8 +201,8 @@ void DiagoBPCG::diag_hsub( template void DiagoBPCG::calc_hsub_with_block( - hamilt::Hamilt *hamilt_in, - const psi::Psi &psi_in, + const HPsiFunc& hpsi_func, + T *psi_in, ct::Tensor& psi_out, ct::Tensor& hpsi_out, ct::Tensor& hsub_out, @@ -216,7 +210,7 @@ void DiagoBPCG::calc_hsub_with_block( ct::Tensor& eigenvalue_out) { // Apply the H operator to psi and obtain the hpsi matrix. - this->calc_hpsi_with_block(hamilt_in, psi_in, hpsi_out); + this->calc_hpsi_with_block(hpsi_func, psi_in, hpsi_out); // Diagonalization of the subspace matrix. this->diag_hsub(psi_out,hpsi_out, hsub_out, eigenvalue_out); @@ -250,19 +244,19 @@ void DiagoBPCG::calc_hsub_with_block_exit( template void DiagoBPCG::diag( - hamilt::Hamilt* hamilt_in, - psi::Psi& psi_in, + const HPsiFunc& hpsi_func, + T *psi_in, Real* eigenvalue_in) { const int current_scf_iter = hsolver::DiagoIterAssist::SCF_ITER; // Get the pointer of the input psi - this->psi = std::move(ct::TensorMap(psi_in.get_pointer(), t_type, device_type, {this->n_band, this->n_basis})); + this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis})); // Update the precondition array this->calc_prec(); // Improving the initial guess of the wave function psi through a subspace diagonalization. - this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); + this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); setmem_complex_op()(this->grad_old.template data(), 0, this->n_basis * this->n_band); @@ -293,7 +287,7 @@ void DiagoBPCG::diag( syncmem_complex_op()(this->grad_old.template data(), this->grad.template data(), n_basis * n_band); // Calculate H|grad> matrix - this->calc_hpsi_with_block(hamilt_in, this->grad_wrapper[0], this->hgrad); + this->calc_hpsi_with_block(hpsi_func, this->grad.template data(), /*this->grad_wrapper[0],*/ this->hgrad); // optimize psi as well as the hpsi // 1. normalize grad @@ -305,7 +299,7 @@ void DiagoBPCG::diag( this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub); if (current_scf_iter == 1 && ntry % this->nline == 0) { - this->calc_hsub_with_block(hamilt_in, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); + this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); } } while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr)); diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index c14472de2e..2ca8167f9e 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -50,20 +50,24 @@ class DiagoBPCG * This function allocates all the related variables, such as hpsi, hsub, before the diag call. * It is called by the HsolverPW::initDiagh() function. * - * @param psi_in The input wavefunction psi. + * @param nband The number of bands. + * @param nbasis The number of basis functions. Leading dimension of psi. */ - void init_iter(const psi::Psi &psi_in); + void init_iter(const int nband, const int nbasis); + + using HPsiFunc = std::function; /** * @brief Diagonalize the Hamiltonian using the BPCG method. * * This function is called by the HsolverPW::solve() function. * - * @param phm_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator. - * @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major]. + * @param hpsi_func A function computing the product of the Hamiltonian matrix H + * and a wavefunction blockvector X. + * @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major]. * @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major]. */ - void diag(hamilt::Hamilt *phm_in, psi::Psi &psi, Real *eigenvalue_in); + void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in); private: @@ -103,7 +107,6 @@ class DiagoBPCG /// work for some calculations within this class, including rotate_wf call ct::Tensor work = {}; - psi::Psi* grad_wrapper; /** * @brief Update the precondition array. * @@ -134,13 +137,14 @@ class DiagoBPCG * psi_in[dim: n_basis x n_band, column major, lda = n_basis_max], * hpsi_out[dim: n_basis x n_band, column major, lda = n_basis_max]. * - * @param hamilt_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator. + * @param hpsi_func A function computing the product of the Hamiltonian matrix H + * and a wavefunction blockvector X. * @param psi_in The input wavefunction psi. * @param hpsi_out Pointer to the array where the resulting hpsi matrix will be stored. */ void calc_hpsi_with_block( - hamilt::Hamilt* hamilt_in, - const psi::Psi& psi_in, + const HPsiFunc& hpsi_func, + T *psi_in, ct::Tensor& hpsi_out); /** @@ -220,16 +224,16 @@ class DiagoBPCG * hsub_out[dim: n_band x n_band, column major, lda = n_band], * eigenvalue_out[dim: n_basis_max, column major]. * - * @param hamilt_in Pointer to the Hamiltonian object. - * @param psi_in Input wavefunction. + * @param hpsi_func A function computing the product of matrix H and wavefunction blockvector X. + * @param psi_in Input wavefunction pointer. * @param psi_out Output wavefunction. * @param hpsi_out Product of psi_out and Hamiltonian. * @param hsub_out Subspace matrix output. * @param eigenvalue_out Computed eigen. */ void calc_hsub_with_block( - hamilt::Hamilt* hamilt_in, - const psi::Psi& psi_in, + const HPsiFunc& hpsi_func, + T *psi_in, ct::Tensor& psi_out, ct::Tensor& hpsi_out, ct::Tensor& hsub_out, ct::Tensor& workspace_in, ct::Tensor& eigenvalue_out); @@ -314,8 +318,6 @@ class DiagoBPCG */ bool test_error(const ct::Tensor& err_in, Real thr_in); - using hpsi_info = typename hamilt::Operator::hpsi_info; - using ct_Device = typename ct::PsiToContainer::type; using setmem_var_op = ct::kernels::set_memory; using resmem_var_op = ct::kernels::resize_memory; diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 945a65f87d..4572495ccd 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -467,9 +467,27 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, } else if (this->method == "bpcg") { + const int nband = psi.get_nbands(); + const int nbasis = psi.get_nbasis(); + auto ngk_pointer = psi.get_ngk_pointer(); + // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec + auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + ModuleBase::timer::tick("DavSubspace", "hpsi_func"); + + // Convert "pointer data stucture" to a psi::Psi object + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_pointer); + + psi::Range bands_range(true, 0, 0, nvec - 1); + + using hpsi_info = typename hamilt::Operator::hpsi_info; + hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out); + hm->ops->hPsi(info); + + ModuleBase::timer::tick("DavSubspace", "hpsi_func"); + }; DiagoBPCG bpcg(pre_condition.data()); - bpcg.init_iter(psi); - bpcg.diag(hm, psi, eigenvalue); + bpcg.init_iter(nband, nbasis); + bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue); } else if (this->method == "dav_subspace") { diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index 39eacf5db6..6274f75d74 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -130,10 +130,32 @@ class DiagoBPCGPrepare psi_local.fix_k(0); double start, end; start = MPI_Wtime(); - bpcg.init_iter(psi_local); - bpcg.diag(ha,psi_local,en); - bpcg.diag(ha,psi_local,en); - bpcg.diag(ha,psi_local,en); + using T = std::complex; + const int dim = DIAGOTEST::npw; + const std::vector &h_mat = DIAGOTEST::hmatrix_local; + auto hpsi_func = [h_mat, dim](T *psi_in, T *hpsi_out, + const int ld_psi, const int nvec) { + auto one = std::make_unique(1.0); + auto zero = std::make_unique(0.0); + const T *one_ = one.get(); + const T *zero_ = zero.get(); + + base_device::DEVICE_CPU *ctx = {}; + // hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec) + hsolver::gemm_op()( + ctx, 'N', 'N', + dim, nvec, dim, + one_, + h_mat.data(), dim, + psi_in, ld_psi, + zero_, + hpsi_out, ld_psi); + }; + bpcg.init_iter(nband, npw); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en); end = MPI_Wtime(); //if(mypnum == 0) printf("diago time:%7.3f\n",end-start); delete [] DIAGOTEST::npw_local; @@ -219,29 +241,6 @@ TEST(DiagoBPCGTest, Hamilt) } }*/ -// bpcg for a 2x2 matrix -#ifdef __MPI -#else -TEST(DiagoBPCGTest, TwoByTwo) -{ - int dim = 2; - int nband = 2; - ModuleBase::ComplexMatrix hm(2, 2); - hm(0, 0) = std::complex{4.0, 0.0}; - hm(0, 1) = std::complex{1.0, 0.0}; - hm(1, 0) = std::complex{1.0, 0.0}; - hm(1, 1) = std::complex{3.0, 0.0}; - // nband, npw, sub, sparsity, reorder, eps, maxiter, threshold - DiagoBPCGPrepare dcp(nband, dim, 0, true, 1e-4, 50, 1e-10); - hsolver::DiagoIterAssist>::PW_DIAG_NMAX = dcp.maxiter; - hsolver::DiagoIterAssist>::PW_DIAG_THR = dcp.eps; - HPsi> hpsi; - hpsi.create(nband, dim); - DIAGOTEST::hmatrix = hm; - DIAGOTEST::npw = dim; - dcp.CompareEigen(hpsi.precond()); -} -#endif TEST(DiagoBPCGTest, readH) {