Skip to content
14 changes: 10 additions & 4 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Operator<T, Device>::Operator(){}
template<typename T, typename Device>
Operator<T, Device>::~Operator()
{
if(this->hpsi != nullptr) delete this->hpsi;
if(this->hpsi != nullptr) { delete this->hpsi;
}
Operator* last = this->next_op;
Operator* last_sub = this->next_sub_op;
while(last != nullptr || last_sub != nullptr)
Expand Down Expand Up @@ -61,8 +62,11 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
}

auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {

// a "psi" with the bands of needed range
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis());
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);


switch (op->get_act_type())
{
case 2:
Expand Down Expand Up @@ -100,9 +104,11 @@ void Operator<T, Device>::init(const int ik_in)
template<typename T, typename Device>
void Operator<T, Device>::add(Operator* next)
{
if(next==nullptr) return;
if(next==nullptr) { return;
}
next->is_first_node = false;
if(next->next_op != nullptr) this->add(next->next_op);
if(next->next_op != nullptr) { this->add(next->next_op);
}
Operator* last = this;
//loop to end of the chain
while(last->next_op != nullptr)
Expand Down
40 changes: 22 additions & 18 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands());
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks);

if (skip_charge)
{
Expand Down Expand Up @@ -357,19 +357,27 @@ template <typename T, typename Device>
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi::Psi<T, Device>& psi,
std::vector<Real>& pre_condition,
Real* eigenvalue)
Real* eigenvalue,
const int& nk_nums)
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, this->rank_in_pool, this->nproc_in_pool};
#else
const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool};
#endif

auto ngk_pointer = psi.get_ngk_pointer();

std::vector<int> ngk_vector(nk_nums, 0);
for (int i = 0; i < nk_nums; i++)
{
ngk_vector[i] = ngk_pointer[i];
}

if (this->method == "cg")
{
// wrap the subspace_func into a lambda function
auto ngk_pointer = psi.get_ngk_pointer();
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
Expand All @@ -379,12 +387,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
1,
psi_in.shape().dim_size(0),
psi_in.shape().dim_size(1),
ngk_pointer);
ngk_vector);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
1,
psi_out.shape().dim_size(0),
psi_out.shape().dim_size(1),
ngk_pointer);
ngk_vector);
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_in.shape().dim_size(0)}));
Expand All @@ -403,7 +411,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
using ct_Device = typename ct::PsiToContainer<Device>::type;

// wrap the hpsi_func and spsi_func into a lambda function
auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
Expand All @@ -414,7 +422,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ngk_pointer);
ngk_vector);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
Expand Down Expand Up @@ -473,13 +481,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
{
const int nband = psi.get_nbands();
const int nbasis = psi.get_nbasis();
auto ngk_pointer = psi.get_ngk_pointer();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);

psi::Range bands_range(true, 0, 0, nvec - 1);

Expand All @@ -495,13 +502,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
}
else if (this->method == "dav_subspace")
{
auto ngk_pointer = psi.get_ngk_pointer();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);

psi::Range bands_range(true, 0, 0, nvec - 1);

Expand Down Expand Up @@ -546,15 +552,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi

// Davidson matrix-blockvector functions

auto ngk_pointer = psi.get_ngk_pointer();
/// wrap hpsi into lambda function, Matrix \times blockvector
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("David", "hpsi_func");

// Convert pointer of psi_in to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);

psi::Range bands_range(true, 0, 0, nvec - 1);

Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class HSolverPW
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi::Psi<T, Device>& psi,
std::vector<Real>& pre_condition,
Real* eigenvalue);
Real* eigenvalue,
const int& nk_nums);

// calculate the precondition array for diagonalization in PW base
void update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0);
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
this->update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0());
/// solve eigenvector and eigenvalue for H(k)
double* p_eigenvalues = &(pes->ekb(ik, 0));
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues);
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues, nks);
}

#ifdef __MPI
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_cg_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class DiagoCGPrepare
auto psi_wrapper = psi::Psi<std::complex<float>>(
psi_in.data<std::complex<float>>(), 1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<float>>());
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_cg_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class DiagoCGPrepare
auto psi_wrapper = psi::Psi<double>(
psi_in.data<double>(), 1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
using hpsi_info = typename hamilt::Operator<double>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<double>());
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_cg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class DiagoCGPrepare
auto psi_wrapper = psi::Psi<std::complex<double>>(
psi_in.data<std::complex<double>>(), 1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
using hpsi_info = typename hamilt::Operator<std::complex<double>>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<double>>());
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class DiagoDavPrepare
auto hpsi_func = [phm](std::complex<float>* psi_in,std::complex<float>* hpsi_out,
const int ld_psi, const int nvec)
{
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, nullptr);
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, true);
psi::Range bands_range(true, 0, 0, nvec-1);
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class DiagoDavPrepare
auto hpsi_func = [phm](double* psi_in,double* hpsi_out,
const int ld_psi, const int nvec)
{
auto psi_iter_wrapper = psi::Psi<double>(psi_in, 1, nvec, ld_psi, nullptr);
auto psi_iter_wrapper = psi::Psi<double>(psi_in, 1, nvec, ld_psi, true);
psi::Range bands_range(true, 0, 0, nvec-1);
using hpsi_info = typename hamilt::Operator<double>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class DiagoDavPrepare
auto hpsi_func = [phm](std::complex<double>* psi_in,std::complex<double>* hpsi_out,
const int ld_psi, const int nvec)
{
auto psi_iter_wrapper = psi::Psi<std::complex<double>>(psi_in, 1, nvec, ld_psi, nullptr);
auto psi_iter_wrapper = psi::Psi<std::complex<double>>(psi_in, 1, nvec, ld_psi, true);
psi::Range bands_range(true, 0, 0, nvec-1);
using hpsi_info = typename hamilt::Operator<std::complex<double>>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/get_wf_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ void IState_Envelope::begin(const UnitCell& ucell,
printf(" Estimated on-the-fly memory consuming by IState_Envelope::begin::wfc_k_grid: %f MB\n", mem_size);

// for pw_wfc in G space
psi::Psi<std::complex<double>> psi_g(kv.ngk.data());
psi::Psi<std::complex<double>> psi_g;
if (out_wf || out_wf_r)
{
psi_g.resize(nks, nbands, pw_wfc->npwk_max);
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/test/write_wfc_nao_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class WriteWfcLcaoTest : public testing::Test
TEST_F(WriteWfcLcaoTest, WriteWfcLcao)
{
// create a psi object
psi::Psi<double> my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local);
psi::Psi<double> my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, true);
PARAM.sys.global_out_dir = "./";
ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1);

Expand Down Expand Up @@ -196,7 +196,7 @@ TEST_F(WriteWfcLcaoTest, WriteWfcLcao)

TEST_F(WriteWfcLcaoTest, WriteWfcLcaoComplex)
{
psi::Psi<std::complex<double>> my_psi(psi_local_complex.data(), nk, nbands_local, nbasis_local);
psi::Psi<std::complex<double>> my_psi(psi_local_complex.data(), nk, nbands_local, nbasis_local, true);
PARAM.sys.global_out_dir = "./";
ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1);

Expand Down
6 changes: 4 additions & 2 deletions source/module_lr/utils/lr_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ namespace LR_Util
{
assert(psi_kfirst.get_nk() == 1);
assert(nk_in * nbasis_in == psi_kfirst.get_nbasis());

int ib_now = psi_kfirst.get_current_b();
psi_kfirst.fix_b(0); // for get_pointer() to get the head pointer
psi::Psi<T, Device> psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, psi_kfirst.get_ngk_pointer(), false);
psi::Psi<T, Device> psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, false);
psi_kfirst.fix_b(ib_now);
return psi_bfirst;
}
Expand All @@ -121,8 +122,9 @@ namespace LR_Util
{
int ib_now = psi_bfirst.get_current_b();
int ik_now = psi_bfirst.get_current_k();

psi_bfirst.fix_kb(0, 0); // for get_pointer() to get the head pointer
psi::Psi<T, Device> psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), psi_bfirst.get_ngk_pointer(), true);
psi::Psi<T, Device> psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), true);
psi_bfirst.fix_kb(ik_now, ib_now);
return psi_kfirst;
}
Expand Down
Loading
Loading