Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Operator<T, Device>::Operator(){}
template<typename T, typename Device>
Operator<T, Device>::~Operator()
{
if(this->hpsi != nullptr) delete this->hpsi;
if(this->hpsi != nullptr) { delete this->hpsi;
}
Operator* last = this->next_op;
Operator* last_sub = this->next_sub_op;
while(last != nullptr || last_sub != nullptr)
Expand Down Expand Up @@ -69,7 +70,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
op->act(psi_wrapper, *this->hpsi, nbands);
break;
default:
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ik_nbas(op->ik), is_first_node);
break;
}
};
Expand Down Expand Up @@ -100,9 +101,11 @@ void Operator<T, Device>::init(const int ik_in)
template<typename T, typename Device>
void Operator<T, Device>::add(Operator* next)
{
if(next==nullptr) return;
if(next==nullptr) { return;
}
next->is_first_node = false;
if(next->next_op != nullptr) this->add(next->next_op);
if(next->next_op != nullptr) { this->add(next->next_op);
}
Operator* last = this;
//loop to end of the chain
while(last->next_op != nullptr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void Velocity::act
) const
{
ModuleBase::timer::tick("Operator", "Velocity");
const int npw = psi_in->get_ngk(this->ik);
const int npw = psi_in->get_ik_nbas(this->ik);
const int max_npw = psi_in->get_nbasis() / psi_in->npol;
const int npol = psi_in->npol;
const std::complex<double>* tmpsi_in = psi0;
Expand Down
11 changes: 6 additions & 5 deletions source/module_hsolver/diago_iter_assist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*

if (base_device::get_device_type(ctx) == base_device::GpuDevice)
{
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ngk(0));
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ik_nbas(0));
T* ppsi = psi_temp.get_pointer();
// hpsi and spsi share the temp space
T* temp = nullptr;
Expand All @@ -212,7 +212,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
{
// psi_temp is one band psi, psi is all bands psi, the range always is 1 for the only band in psi_temp
syncmem_complex_op()(ctx, ctx, ppsi, psi + i * psi_nc, psi_nc);
psi::Range band_by_band_range(1, 0, 0, 0);
psi::Range band_by_band_range(true, 0, 0, 0);
hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi);

// H|Psi> to get hpsi for target band
Expand Down Expand Up @@ -246,7 +246,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
}
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
{
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ngk(0));
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ik_nbas(0));
T* ppsi = psi_temp.get_pointer();
syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size());
// hpsi and spsi share the temp space
Expand All @@ -256,7 +256,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*

T* hpsi = temp;
// do hPsi for all bands
psi::Range all_bands_range(1, 0, 0, nstart - 1);
psi::Range all_bands_range(true, 0, 0, nstart - 1);
hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi);
pHamilt->ops->hPsi(hpsi_in);

Expand Down Expand Up @@ -586,8 +586,9 @@ bool DiagoIterAssist<T, Device>::test_exit_cond(const int& ntry, const int& notc
//================================================================

bool scf = true;
if (PARAM.inp.calculation == "nscf")
if (PARAM.inp.calculation == "nscf") {
scf = false;
}

// If ntry <=5, try to do it better, if ntry > 5, exit.
const bool f1 = (ntry <= 5);
Expand Down
12 changes: 6 additions & 6 deletions source/module_io/unk_overlap_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ std::complex<double> unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw,
ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw);


for (int igl = 0; igl < evc->get_ngk(ik_L); igl++)
for (int igl = 0; igl < evc->get_ik_nbas(ik_L); igl++)
{
unk_L[wfcpw->getigl2ig(ik_L,igl)] = evc[0](ik_L, iband_L, igl);
}

for (int igl = 0; igl < evc->get_ngk(ik_R); igl++)
for (int igl = 0; igl < evc->get_ik_nbas(ik_R); igl++)
{
unk_R[wfcpw->getigl2ig(ik_R,igl)] = evc[0](ik_R, iband_R, igl);
}
Expand Down Expand Up @@ -103,7 +103,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_r, psi_r, ik_R);

for (int ig = 0; ig < evc->get_ngk(ik_R); ig++)
for (int ig = 0; ig < evc->get_ik_nbas(ik_R); ig++)
{
result = result + conj(psi_r[ig]) * evc[0](ik_R, iband_R, ig);
}
Expand Down Expand Up @@ -143,12 +143,12 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf

for(int i = 0; i < PARAM.globalv.npol; i++)
{
for (int igl = 0; igl < evc->get_ngk(ik_L); igl++)
for (int igl = 0; igl < evc->get_ik_nbas(ik_L); igl++)
{
unk_L[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_L, iband_L, igl + i * npwx);
}

for (int igl = 0; igl < evc->get_ngk(ik_R); igl++)
for (int igl = 0; igl < evc->get_ik_nbas(ik_R); igl++)
{
unk_R[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_R, iband_R, igl + i * npwx);
}
Expand Down Expand Up @@ -223,7 +223,7 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho

for (int i = 0; i < PARAM.globalv.npol; i++)
{
for(int ig = 0; ig < evc->get_ngk(ik_R); ig++)
for(int ig = 0; ig < evc->get_ik_nbas(ik_R); ig++)
{
if( i == 0 ) { result = result + conj( psi_up[ig] ) * evc[0](ik_R, iband_R, ig);
}
Expand Down
6 changes: 3 additions & 3 deletions source/module_io/write_vxc_lip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ namespace ModuleIO
hpsi_localxc.fix_k(ik);
#ifdef __DEBUG
assert(hpsi_localxc.get_current_nbas() == psi_pw.get_current_nbas());
assert(hpsi_localxc.get_current_nbas() == hpsi_localxc.get_ngk(ik));
assert(hpsi_localxc.get_current_nbas() == hpsi_localxc.get_ik_nbas(ik));
#endif
/// wrap psi and act band-by-band (the same result as act all bands at once)
// for (int ib = 0;ib < psi_pw.get_nbands();++ib)
// {
// std::cout<<"ib="<<ib<<std::endl;
// psi::Psi<T> psi_single_band(&psi_pw(ik, ib, 0), 1, 1, psi_pw.get_current_nbas());
// psi::Psi<T> hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas());
// vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik));
// vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ik_nbas(ik));
// }
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik));
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ik_nbas(ik));
delete vxcs_op_pw;
std::vector<T> vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands());
Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands);
Expand Down
54 changes: 31 additions & 23 deletions source/module_psi/psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const
this->nk = nk_in;
this->nbands = nbd_in;
this->nbasis = nbs_in;
this->current_nbasis = nbs_in;
this->psi_current = this->psi = psi_pointer;
this->allocate_inside = false;
// Currently only GPU's implementation is supported for device recording!
Expand Down Expand Up @@ -148,7 +147,6 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const Psi& psi_in)
psi_in.get_pointer() - psi_in.get_psi_bias(),
psi_in.size());
this->psi_bias = psi_in.get_psi_bias();
this->current_nbasis = psi_in.get_current_nbas();
this->psi_current = this->psi + psi_in.get_psi_bias();
}

Expand Down Expand Up @@ -200,7 +198,6 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
psi_in.size());
}
this->psi_bias = psi_in.get_psi_bias();
this->current_nbasis = psi_in.get_current_nbas();
this->psi_current = this->psi + psi_in.get_psi_bias();
}

Expand All @@ -213,7 +210,6 @@ void Psi<T, Device>::resize(const int nks_in, const int nbands_in, const int nba
this->nk = nks_in;
this->nbands = nbands_in;
this->nbasis = nbasis_in;
this->current_nbasis = nbasis_in;
this->psi_current = this->psi;
// GlobalV::ofs_device << "allocated xxx MB memory for psi" << std::endl;
}
Expand Down Expand Up @@ -276,27 +272,26 @@ template <typename T, typename Device> std::size_t Psi<T, Device>::size() const

/// Select k-point `ik` as the current one and repoint `psi_current` at its data.
///
/// - k_first layout (data ordered [ik][iband][ibasis]): the current band is reset
///   to 0 and the offset is `ik * nbands * nbasis`.
/// - band-first layout (data ordered [iband][ik][ibasis]): the current band is
///   kept and the offset is `current_b * nk * nbasis + ik * nbasis`.
///
/// @param ik  k-point index; must satisfy 0 <= ik < nk (checked with assert only,
///            so release builds perform no range check).
///
/// NOTE(review): the method is const yet writes current_k / current_b / psi_bias /
/// psi_current, so these members are presumably declared `mutable` -- confirm in psi.h.
template <typename T, typename Device> void Psi<T, Device>::fix_k(const int ik) const
{
    assert(ik >= 0 && ik < this->nk);

    this->current_k = ik;
    if (this->k_first)
    {
        // k-first storage: fixing k always restarts from the first band
        this->current_b = 0;
        this->psi_bias = ik * this->nbands * this->nbasis;
    }
    else
    {
        // band-first storage: this->current_b remains unchanged
        this->psi_bias = this->current_b * this->nk * this->nbasis + ik * this->nbasis;
    }
    this->psi_current = this->psi + this->psi_bias;
}

template <typename T, typename Device> void Psi<T, Device>::fix_b(const int ib) const
{
assert(ib >= 0);
Expand Down Expand Up @@ -366,15 +361,28 @@ template <typename T, typename Device> int Psi<T, Device>::get_current_b() const
return this->current_b;
}

/// Number of basis functions of the currently selected k-point.
///
/// @return `ngk[current_k]` when an ngk array is attached to this object,
///         otherwise the common `nbasis`. The returned reference aliases either
///         that ngk entry or this object's `nbasis` member.
template <typename T, typename Device> const int& Psi<T, Device>::get_current_nbas() const
{
    if (this->ngk != nullptr)
    {
        return this->ngk[this->current_k];
    }
    // no per-k table available: every k-point uses the same basis size
    return this->nbasis;
}

/// Number of basis functions of k-point `ik_in`.
///
/// @param ik_in  k-point index used to look up the ngk array; it is not
///               range-checked here.
/// @return `ngk[ik_in]` when an ngk array is attached to this object,
///         otherwise the common `nbasis` (in which case `ik_in` is ignored).
///         The returned reference aliases either that ngk entry or this
///         object's `nbasis` member.
template <typename T, typename Device> const int& Psi<T, Device>::get_ik_nbas(const int ik_in) const
{
    if (this->ngk != nullptr)
    {
        return this->ngk[ik_in];
    }
    // no per-k table available: every k-point uses the same basis size
    return this->nbasis;
}

template <typename T, typename Device> void Psi<T, Device>::zero_out()
Expand Down
12 changes: 9 additions & 3 deletions source/module_psi/psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,27 @@ class Psi
/// if k_first=true, ikb=ik, ikb2=iband
/// if k_first=false, ikb=iband, ikb2=ik
T& operator()(const int ikb1, const int ikb2, const int ibasis) const;

/// use operator "(ikb2, ibasis)" to reach target element for current k
/// if k_first=true, ikb2=iband
/// if k_first=false, ikb2=ik
T& operator()(const int ikb2, const int ibasis) const;

// use operator "(ibasis)" to reach target element for current k and current band
T& operator()(const int ibasis) const;

// return current k-point index
int get_current_k() const;

// return current band index
int get_current_b() const;
// return the number of basis functions of the current k-point:
// ngk[current_k] if an ngk array is attached, otherwise nbasis
const int& get_current_nbas() const;

// return the number of basis functions of k-point ik_in:
// ngk[ik_in] if an ngk array is attached, otherwise nbasis
const int& get_ik_nbas(const int ik_in) const;

// return ngk array of psi
const int* get_ngk_pointer() const;
// return k_first
Expand Down
10 changes: 5 additions & 5 deletions source/module_psi/test/psi_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,23 @@ TEST_F(TestPsi, get_val)
EXPECT_EQ(psi_object14->get_psi_bias(), 0);
}

// get_ik_nbas(ik) and get_ngk_pointer() must expose the ngk array handed to the
// constructor, for every supported element type.
TEST_F(TestPsi, get_ik_nbas)
{
    // Automatic objects instead of `new`: nothing to delete, so the test cannot leak.
    const psi::Psi<std::complex<double>> psi_object21(&ngk[0]);
    const psi::Psi<double> psi_object22(&ngk[0]);
    const psi::Psi<std::complex<float>> psi_object23(&ngk[0]);
    const psi::Psi<float> psi_object24(&ngk[0]);

    EXPECT_EQ(psi_object21.get_ik_nbas(2), ngk[2]);
    EXPECT_EQ(psi_object21.get_ngk_pointer()[0], ngk[0]);

    EXPECT_EQ(psi_object22.get_ik_nbas(2), ngk[2]);
    EXPECT_EQ(psi_object22.get_ngk_pointer()[0], ngk[0]);

    EXPECT_EQ(psi_object23.get_ik_nbas(2), ngk[2]);
    EXPECT_EQ(psi_object23.get_ngk_pointer()[0], ngk[0]);

    EXPECT_EQ(psi_object24.get_ik_nbas(2), ngk[2]);
    EXPECT_EQ(psi_object24.get_ngk_pointer()[0], ngk[0]);
}

Expand Down
Loading