Skip to content

Commit 8823630

Browse files
committed
refactor get_current_nbas get_ik_nbas fix_k
1 parent 0317913 commit 8823630

File tree

8 files changed

+57
-51
lines changed

8 files changed

+57
-51
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6969
op->act(psi_wrapper, *this->hpsi, nbands);
7070
break;
7171
default:
72-
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
72+
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ik_nbas(op->ik), is_first_node);
7373
break;
7474
}
7575
};

source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void Velocity::act
4747
) const
4848
{
4949
ModuleBase::timer::tick("Operator", "Velocity");
50-
const int npw = psi_in->get_ngk(this->ik);
50+
const int npw = psi_in->get_ik_nbas(this->ik);
5151
const int max_npw = psi_in->get_nbasis() / psi_in->npol;
5252
const int npol = psi_in->npol;
5353
const std::complex<double>* tmpsi_in = psi0;

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
199199

200200
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
201201
{
202-
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ngk(0));
202+
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ik_nbas(0));
203203
T* ppsi = psi_temp.get_pointer();
204204
// hpsi and spsi share the temp space
205205
T* temp = nullptr;
@@ -246,7 +246,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
246246
}
247247
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
248248
{
249-
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ngk(0));
249+
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ik_nbas(0));
250250
T* ppsi = psi_temp.get_pointer();
251251
syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size());
252252
// hpsi and spsi share the temp space

source/module_io/unk_overlap_pw.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ std::complex<double> unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw,
2929
ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw);
3030

3131

32-
for (int igl = 0; igl < evc->get_ngk(ik_L); igl++)
32+
for (int igl = 0; igl < evc->get_ik_nbas(ik_L); igl++)
3333
{
3434
unk_L[wfcpw->getigl2ig(ik_L,igl)] = evc[0](ik_L, iband_L, igl);
3535
}
3636

37-
for (int igl = 0; igl < evc->get_ngk(ik_R); igl++)
37+
for (int igl = 0; igl < evc->get_ik_nbas(ik_R); igl++)
3838
{
3939
unk_R[wfcpw->getigl2ig(ik_R,igl)] = evc[0](ik_R, iband_R, igl);
4040
}
@@ -103,7 +103,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
103103
// (3) calculate the overlap in ik_L and ik_R
104104
wfcpw->real2recip(psi_r, psi_r, ik_R);
105105

106-
for (int ig = 0; ig < evc->get_ngk(ik_R); ig++)
106+
for (int ig = 0; ig < evc->get_ik_nbas(ik_R); ig++)
107107
{
108108
result = result + conj(psi_r[ig]) * evc[0](ik_R, iband_R, ig);
109109
}
@@ -143,12 +143,12 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf
143143

144144
for(int i = 0; i < PARAM.globalv.npol; i++)
145145
{
146-
for (int igl = 0; igl < evc->get_ngk(ik_L); igl++)
146+
for (int igl = 0; igl < evc->get_ik_nbas(ik_L); igl++)
147147
{
148148
unk_L[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_L, iband_L, igl + i * npwx);
149149
}
150150

151-
for (int igl = 0; igl < evc->get_ngk(ik_R); igl++)
151+
for (int igl = 0; igl < evc->get_ik_nbas(ik_R); igl++)
152152
{
153153
unk_R[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_R, iband_R, igl + i * npwx);
154154
}
@@ -223,7 +223,7 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
223223

224224
for (int i = 0; i < PARAM.globalv.npol; i++)
225225
{
226-
for(int ig = 0; ig < evc->get_ngk(ik_R); ig++)
226+
for(int ig = 0; ig < evc->get_ik_nbas(ik_R); ig++)
227227
{
228228
if( i == 0 ) { result = result + conj( psi_up[ig] ) * evc[0](ik_R, iband_R, ig);
229229
}

source/module_io/write_vxc_lip.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,17 +148,17 @@ namespace ModuleIO
148148
hpsi_localxc.fix_k(ik);
149149
#ifdef __DEBUG
150150
assert(hpsi_localxc.get_current_nbas() == psi_pw.get_current_nbas());
151-
assert(hpsi_localxc.get_current_nbas() == hpsi_localxc.get_ngk(ik));
151+
assert(hpsi_localxc.get_current_nbas() == hpsi_localxc.get_ik_nbas(ik));
152152
#endif
153153
/// wrap psi and act band-by-band (gives the same result as acting on all bands at once)
154154
// for (int ib = 0;ib < psi_pw.get_nbands();++ib)
155155
// {
156156
// std::cout<<"ib="<<ib<<std::endl;
157157
// psi::Psi<T> psi_single_band(&psi_pw(ik, ib, 0), 1, 1, psi_pw.get_current_nbas());
158158
// psi::Psi<T> hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas());
159-
// vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik));
159+
// vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ik_nbas(ik));
160160
// }
161-
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik));
161+
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ik_nbas(ik));
162162
delete vxcs_op_pw;
163163
std::vector<T> vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands());
164164
Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands);

source/module_psi/psi.cpp

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const
7676
this->nk = nk_in;
7777
this->nbands = nbd_in;
7878
this->nbasis = nbs_in;
79-
this->current_nbasis = nbs_in;
8079
this->psi_current = this->psi = psi_pointer;
8180
this->allocate_inside = false;
8281
// Currently only GPU's implementation is supported for device recording!
@@ -148,7 +147,6 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const Psi& psi_in)
148147
psi_in.get_pointer() - psi_in.get_psi_bias(),
149148
psi_in.size());
150149
this->psi_bias = psi_in.get_psi_bias();
151-
this->current_nbasis = psi_in.get_current_nbas();
152150
this->psi_current = this->psi + psi_in.get_psi_bias();
153151
}
154152

@@ -200,7 +198,6 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
200198
psi_in.size());
201199
}
202200
this->psi_bias = psi_in.get_psi_bias();
203-
this->current_nbasis = psi_in.get_current_nbas();
204201
this->psi_current = this->psi + psi_in.get_psi_bias();
205202
}
206203

@@ -213,7 +210,6 @@ void Psi<T, Device>::resize(const int nks_in, const int nbands_in, const int nba
213210
this->nk = nks_in;
214211
this->nbands = nbands_in;
215212
this->nbasis = nbasis_in;
216-
this->current_nbasis = nbasis_in;
217213
this->psi_current = this->psi;
218214
// GlobalV::ofs_device << "allocated xxx MB memory for psi" << std::endl;
219215
}
@@ -276,27 +272,26 @@ template <typename T, typename Device> std::size_t Psi<T, Device>::size() const
276272

277273
template <typename T, typename Device> void Psi<T, Device>::fix_k(const int ik) const
278274
{
279-
assert(ik >= 0);
280-
this->current_k = ik;
281-
if (this->ngk != nullptr && this->npol != 2)
282-
this->current_nbasis = this->ngk[ik];
283-
else
284-
this->current_nbasis = this->nbasis;
275+
assert(ik >= 0 && ik < this->nk);
285276

286-
if (this->k_first)this->current_b = 0;
287-
int base = this->current_b * this->nk * this->nbasis;
288-
if (ik >= this->nk)
277+
if (this->k_first == true)
289278
{
290-
// mem_saver: fix to base
291-
this->psi_bias = base;
292-
this->psi_current = const_cast<T*>(&(this->psi[base]));
279+
this->current_k = ik;
280+
this->current_b = 0;
281+
282+
this->psi_bias = this->current_k * this->nbands * this->nbasis;
283+
this->psi_current = this->psi + this->psi_bias;
293284
}
294285
else
295286
{
296-
this->psi_bias = k_first ? ik * this->nbands * this->nbasis : base + ik * this->nbasis;
297-
this->psi_current = const_cast<T*>(&(this->psi[psi_bias]));
287+
this->current_k = ik;
288+
// this->current_b remains unchanged
289+
290+
this->psi_bias = this->current_b * this->nk * this->nbasis + this->current_k * this->nbasis;
291+
this->psi_current = this->psi + this->psi_bias;
298292
}
299293
}
294+
300295
template <typename T, typename Device> void Psi<T, Device>::fix_b(const int ib) const
301296
{
302297
assert(ib >= 0);
@@ -350,12 +345,6 @@ template <typename T, typename Device> T& Psi<T, Device>::operator()(const int i
350345
return this->psi_current[ikb2 * this->nbasis + ibasis];
351346
}
352347

353-
template <typename T, typename Device> T& Psi<T, Device>::operator()(const int ibasis) const
354-
{
355-
assert(ibasis >= 0 && ibasis < this->nbasis);
356-
return this->psi_current[ibasis];
357-
}
358-
359348
template <typename T, typename Device> int Psi<T, Device>::get_current_k() const
360349
{
361350
return this->current_k;
@@ -366,15 +355,28 @@ template <typename T, typename Device> int Psi<T, Device>::get_current_b() const
366355
return this->current_b;
367356
}
368357

369-
template <typename T, typename Device> int Psi<T, Device>::get_current_nbas() const
358+
template <typename T, typename Device> const int& Psi<T, Device>::get_current_nbas() const
370359
{
371-
return this->current_nbasis;
360+
if (this->ngk != nullptr)
361+
{
362+
return this->ngk[this->current_k];
363+
}
364+
else
365+
{
366+
return this->nbasis;
367+
}
372368
}
373369

374-
template <typename T, typename Device> const int& Psi<T, Device>::get_ngk(const int ik_in) const
370+
template <typename T, typename Device> const int& Psi<T, Device>::get_ik_nbas(const int ik_in) const
375371
{
376-
if (!this->ngk) return this->nbasis;
377-
return this->ngk[ik_in];
372+
if (this->ngk != nullptr)
373+
{
374+
return this->ngk[ik_in];
375+
}
376+
else
377+
{
378+
return this->nbasis;
379+
}
378380
}
379381

380382
template <typename T, typename Device> void Psi<T, Device>::zero_out()

source/module_psi/psi.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,20 @@ class Psi
9191
/// if k_first=false, ikb2=ik
9292
T& operator()(const int ikb2, const int ibasis) const;
9393
// operator "(ibasis)" (reach target element for current k and current band) is disabled; the declaration is kept commented out for reference
94-
T& operator()(const int ibasis) const;
94+
// T& operator()(const int ibasis) const;
9595

9696
// return current k-point index
9797
int get_current_k() const;
98+
9899
// return current band index
99100
int get_current_b() const;
100-
// return current ngk for PW base
101-
int get_current_nbas() const;
102101

103-
const int& get_ngk(const int ik_in) const;
102+
// return the basis size at the current k-point: ngk[current_k] when ngk is set, otherwise nbasis
103+
const int& get_current_nbas() const;
104+
105+
// return the nbasis of ik_in
106+
const int& get_ik_nbas(const int ik_in) const;
107+
104108
// return ngk array of psi
105109
const int* get_ngk_pointer() const;
106110
// return k_first

source/module_psi/test/psi_test.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,23 +63,23 @@ TEST_F(TestPsi, get_val)
6363
EXPECT_EQ(psi_object14->get_psi_bias(), 0);
6464
}
6565

66-
TEST_F(TestPsi, get_ngk)
66+
TEST_F(TestPsi, get_ik_nbas)
6767
{
6868
psi::Psi<std::complex<double>>* psi_object21 = new psi::Psi<std::complex<double>>(&ngk[0]);
6969
psi::Psi<double>* psi_object22 = new psi::Psi<double>(&ngk[0]);
7070
psi::Psi<std::complex<float>>* psi_object23 = new psi::Psi<std::complex<float>>(&ngk[0]);
7171
psi::Psi<float>* psi_object24 = new psi::Psi<float>(&ngk[0]);
7272

73-
EXPECT_EQ(psi_object21->get_ngk(2), ngk[2]);
73+
EXPECT_EQ(psi_object21->get_ik_nbas(2), ngk[2]);
7474
EXPECT_EQ(psi_object21->get_ngk_pointer()[0], ngk[0]);
7575

76-
EXPECT_EQ(psi_object22->get_ngk(2), ngk[2]);
76+
EXPECT_EQ(psi_object22->get_ik_nbas(2), ngk[2]);
7777
EXPECT_EQ(psi_object22->get_ngk_pointer()[0], ngk[0]);
7878

79-
EXPECT_EQ(psi_object23->get_ngk(2), ngk[2]);
79+
EXPECT_EQ(psi_object23->get_ik_nbas(2), ngk[2]);
8080
EXPECT_EQ(psi_object23->get_ngk_pointer()[0], ngk[0]);
8181

82-
EXPECT_EQ(psi_object24->get_ngk(2), ngk[2]);
82+
EXPECT_EQ(psi_object24->get_ik_nbas(2), ngk[2]);
8383
EXPECT_EQ(psi_object24->get_ngk_pointer()[0], ngk[0]);
8484
}
8585

0 commit comments

Comments
 (0)