Skip to content

Commit 76893ee

Browse files
committed
remove get-ngk in velocity-pw
1 parent 588a335 commit 76893ee

File tree

8 files changed

+45
-73
lines changed

8 files changed

+45
-73
lines changed

source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ void Velocity::act
5959
// }
6060

6161

62-
const int npw = psi_in->get_ngk(this->ik);
63-
// const int npw = psi_in->get_current_nbas();
62+
// const int npw = psi_in->get_ngk(this->ik);
63+
const int npw = psi_in->get_current_nbas();
6464

6565
const int max_npw = psi_in->get_nbasis() / psi_in->npol;
6666
const int npol = psi_in->npol;

source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi<std::complex<float>>& kspsi_all,
172172
const int allbands = bandinfo[5];
173173
const int dim_jmatrix = perbands_ks * allbands_sto + perbands_sto * allbands;
174174

175-
psi::Psi<std::complex<double>> right_hchi(1, perbands_sto, npwx, p_kv->ngk.data());
176-
psi::Psi<std::complex<float>> f_rightchi(1, perbands_sto, npwx, p_kv->ngk.data());
177-
psi::Psi<std::complex<float>> f_right_hchi(1, perbands_sto, npwx, p_kv->ngk.data());
175+
psi::Psi<std::complex<double>> right_hchi(1, perbands_sto, npwx, npw, true);
176+
psi::Psi<std::complex<float>> f_rightchi(1, perbands_sto, npwx, npw, true);
177+
psi::Psi<std::complex<float>> f_right_hchi(1, perbands_sto, npwx, npw, true);
178178

179179
this->p_hamilt_sto->hPsi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto);
180180
this->p_hamilt_sto->hPsi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto);
@@ -206,8 +206,8 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi<std::complex<float>>& kspsi_all,
206206
}
207207
#endif
208208

209-
psi::Psi<std::complex<float>> f_batch_vchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data());
210-
psi::Psi<std::complex<float>> f_batch_vhchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data());
209+
psi::Psi<std::complex<float>> f_batch_vchi(1, bsize_psi * ndim, npwx, npw, true);
210+
psi::Psi<std::complex<float>> f_batch_vhchi(1, bsize_psi * ndim, npwx, npw, true);
211211
std::vector<std::complex<float>> tmpj(ndim * allbands_sto * perbands_sto);
212212

213213
// 1. (<\psi|J|\chi>)^T
@@ -663,19 +663,19 @@ void Sto_EleCond::sKG(const int& smear_type,
663663
//-----------------------------------------------------------
664664
//------------------- allocate -------------------------
665665
size_t ks_memory_cost = perbands_ks * npwx * sizeof(std::complex<float>);
666-
psi::Psi<std::complex<double>> kspsi(1, perbands_ks, npwx, p_kv->ngk.data());
667-
psi::Psi<std::complex<double>> vkspsi(1, perbands_ks * ndim, npwx, p_kv->ngk.data());
666+
psi::Psi<std::complex<double>> kspsi(1, perbands_ks, npwx, npw, true);
667+
psi::Psi<std::complex<double>> vkspsi(1, perbands_ks * ndim, npwx, npw, true);
668668
std::vector<std::complex<double>> expmtmf_fact(perbands_ks), expmtf_fact(perbands_ks);
669-
psi::Psi<std::complex<float>> f_kspsi(1, perbands_ks, npwx, p_kv->ngk.data());
669+
psi::Psi<std::complex<float>> f_kspsi(1, perbands_ks, npwx, npw, true);
670670
ModuleBase::Memory::record("SDFT::kspsi", ks_memory_cost);
671-
psi::Psi<std::complex<float>> f_vkspsi(1, perbands_ks * ndim, npwx, p_kv->ngk.data());
671+
psi::Psi<std::complex<float>> f_vkspsi(1, perbands_ks * ndim, npwx, npw, true);
672672
ModuleBase::Memory::record("SDFT::vkspsi", ks_memory_cost);
673673
psi::Psi<std::complex<float>>* kspsi_all = &f_kspsi;
674674

675675
size_t sto_memory_cost = perbands_sto * npwx * sizeof(std::complex<double>);
676-
psi::Psi<std::complex<double>> sfchi(1, perbands_sto, npwx, p_kv->ngk.data());
676+
psi::Psi<std::complex<double>> sfchi(1, perbands_sto, npwx, npw, true);
677677
ModuleBase::Memory::record("SDFT::sfchi", sto_memory_cost);
678-
psi::Psi<std::complex<double>> smfchi(1, perbands_sto, npwx, p_kv->ngk.data());
678+
psi::Psi<std::complex<double>> smfchi(1, perbands_sto, npwx, npw, true);
679679
ModuleBase::Memory::record("SDFT::smfchi", sto_memory_cost);
680680
#ifdef __MPI
681681
psi::Psi<std::complex<float>> chi_all, hchi_all, psi_all;
@@ -702,8 +702,8 @@ void Sto_EleCond::sKG(const int& smear_type,
702702

703703
const int nbatch_psi = npart_sto;
704704
const int bsize_psi = ceil(double(perbands_sto) / nbatch_psi);
705-
psi::Psi<std::complex<double>> batch_vchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data());
706-
psi::Psi<std::complex<double>> batch_vhchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data());
705+
psi::Psi<std::complex<double>> batch_vchi(1, bsize_psi * ndim, npwx, npw, true);
706+
psi::Psi<std::complex<double>> batch_vhchi(1, bsize_psi * ndim, npwx, npw, true);
707707
ModuleBase::Memory::record("SDFT::batchjpsi", 3 * bsize_psi * ndim * npwx * sizeof(std::complex<double>));
708708

709709
//------------------- sqrt(f)|psi> sqrt(1-f)|psi> ---------------
@@ -781,7 +781,7 @@ void Sto_EleCond::sKG(const int& smear_type,
781781
std::vector<std::complex<float>> j1r(ndim * dim_jmatrix), j2r(ndim * dim_jmatrix);
782782
ModuleBase::Memory::record("SDFT::j1r", sizeof(std::complex<float>) * ndim * dim_jmatrix);
783783
ModuleBase::Memory::record("SDFT::j2r", sizeof(std::complex<float>) * ndim * dim_jmatrix);
784-
psi::Psi<std::complex<double>> tmphchil(1, perbands_sto, npwx, p_kv->ngk.data());
784+
psi::Psi<std::complex<double>> tmphchil(1, perbands_sto, npwx, npw, true);
785785
ModuleBase::Memory::record("SDFT::tmphchil/r", sto_memory_cost * 2);
786786

787787
//------------------------ t loop --------------------------

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
248248
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
249249
{
250250
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, dmin, true);
251-
// psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ngk(0));
252251

253252
T* ppsi = psi_temp.get_pointer();
254253
syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size());

source/module_hsolver/hsolver_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
374374
const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool};
375375
#endif
376376

377-
const int cur_nbasis = psi.get_ngk(psi.get_current_k());
377+
const int cur_nbasis = psi.get_current_nbas();
378378

379379
if (this->method == "cg")
380380
{

source/module_io/write_vxc_lip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ namespace ModuleIO
122122
// const ModuleBase::matrix vr_localxc = potxc->get_veff_smooth();
123123

124124
// 2. allocate xc operator
125-
psi::Psi<T> hpsi_localxc(psi_pw.get_nk(), psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_ngk_pointer());
125+
psi::Psi<T> hpsi_localxc(psi_pw.get_nk(), psi_pw.get_nbands(), psi_pw.get_nbasis(), kv.ngk, true);
126126
hpsi_localxc.zero_out();
127127
// std::cout << "hpsi.nk=" << hpsi_localxc.get_nk() << std::endl;
128128
// std::cout << "hpsi.nbands=" << hpsi_localxc.get_nbands() << std::endl;

source/module_psi/psi.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,26 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
6262
sizeof(T) * nk_in * nbd_in * nbs_in);
6363
}
6464

65+
66+
template <typename T, typename Device>
67+
Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector<int>& ngk_in, const bool k_first_in)
68+
{
69+
this->k_first = k_first_in;
70+
this->ngk = ngk_in.data();
71+
this->current_b = 0;
72+
this->current_k = 0;
73+
this->npol = PARAM.globalv.npol;
74+
75+
this->resize(nk_in, nbd_in, nbs_in);
76+
77+
// Currently only GPU's implementation is supported for device recording!
78+
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
79+
base_device::information::record_device_memory<Device>(this->ctx,
80+
GlobalV::ofs_device,
81+
"Psi->resize()",
82+
sizeof(T) * nk_in * nbd_in * nbs_in);
83+
}
84+
6585
// Constructor 8-1:
6686
template <typename T, typename Device>
6787
Psi<T, Device>::Psi(T* psi_pointer,
@@ -195,7 +215,7 @@ Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nban
195215
template <typename T, typename Device>
196216
Psi<T, Device>::Psi(const Psi& psi_in)
197217
{
198-
this->ngk = psi_in.get_ngk_pointer();
218+
this->ngk = psi_in.ngk;
199219
this->npol = psi_in.npol;
200220
this->nk = psi_in.get_nk();
201221
this->nbands = psi_in.get_nbands();
@@ -220,7 +240,7 @@ template <typename T, typename Device>
220240
template <typename T_in, typename Device_in>
221241
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
222242
{
223-
this->ngk = psi_in.get_ngk_pointer();
243+
this->ngk = psi_in.ngk;
224244
this->npol = psi_in.npol;
225245
this->nk = psi_in.get_nk();
226246
this->nbands = psi_in.get_nbands();
@@ -300,12 +320,6 @@ T* Psi<T, Device>::get_pointer(const int& ikb) const
300320
return this->psi_current + ikb * this->nbasis;
301321
}
302322

303-
template <typename T, typename Device>
304-
const int* Psi<T, Device>::get_ngk_pointer() const
305-
{
306-
return this->ngk;
307-
}
308-
309323
template <typename T, typename Device>
310324
const bool& Psi<T, Device>::get_k_first() const
311325
{

source/module_psi/psi.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class Psi
4242
// Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later
4343
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true);
4444

45+
Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector<int>& ngk_in, const bool k_first_in);
46+
4547
// Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in
4648
Psi(const Psi& psi_in, const int nk_in, const int nband_in);
4749

@@ -121,12 +123,13 @@ class Psi
121123
int get_current_nbas() const;
122124

123125
const int& get_ngk(const int ik_in) const;
124-
// return ngk array of psi
125-
const int* get_ngk_pointer() const;
126+
126127
// return k_first
127128
const bool& get_k_first() const;
129+
128130
// return device type of psi
129131
const Device* get_device() const;
132+
130133
// return psi_bias
131134
const int& get_psi_bias() const;
132135

source/module_psi/test/psi_test.cpp

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -63,26 +63,6 @@ TEST_F(TestPsi, get_val)
6363
EXPECT_EQ(psi_object14->get_psi_bias(), 0);
6464
}
6565

66-
// TEST_F(TestPsi, get_ngk)
67-
// {
68-
// psi::Psi<std::complex<double>>* psi_object21 = new psi::Psi<std::complex<double>>(&ngk[0]);
69-
// psi::Psi<double>* psi_object22 = new psi::Psi<double>(&ngk[0]);
70-
// psi::Psi<std::complex<float>>* psi_object23 = new psi::Psi<std::complex<float>>(&ngk[0]);
71-
// psi::Psi<float>* psi_object24 = new psi::Psi<float>(&ngk[0]);
72-
73-
// EXPECT_EQ(psi_object21->get_ngk(2), ngk[2]);
74-
// EXPECT_EQ(psi_object21->get_ngk_pointer()[0], ngk[0]);
75-
76-
// EXPECT_EQ(psi_object22->get_ngk(2), ngk[2]);
77-
// EXPECT_EQ(psi_object22->get_ngk_pointer()[0], ngk[0]);
78-
79-
// EXPECT_EQ(psi_object23->get_ngk(2), ngk[2]);
80-
// EXPECT_EQ(psi_object23->get_ngk_pointer()[0], ngk[0]);
81-
82-
// EXPECT_EQ(psi_object24->get_ngk(2), ngk[2]);
83-
// EXPECT_EQ(psi_object24->get_ngk_pointer()[0], ngk[0]);
84-
// }
85-
8666
TEST_F(TestPsi, get_pointer_op_zero_complex_double)
8767
{
8868
for (int i = 0; i < ink; i++)
@@ -331,30 +311,6 @@ TEST_F(TestPsi, band_first)
331311
EXPECT_EQ(std::get<0>(psi_band_32->to_range(illegal_range1)), nullptr);
332312
EXPECT_EQ(std::get<1>(psi_band_32->to_range(illegal_range2)), 0);
333313

334-
// pointer constructor
335-
// band-first to k-first
336-
// psi::Psi<float> psi_band_32_k(psi_band_32->get_pointer(), psi_band_32->get_nk(), psi_band_32->get_nbands(), psi_band_32->get_nbasis(), psi_band_32->get_ngk_pointer(), true);
337-
// k-first to band-first
338-
// psi::Psi<float> psi_band_32_b(psi_band_32_k.get_pointer(), psi_band_32_k.get_nk(), psi_band_32_k.get_nbands(), psi_band_32_k.get_nbasis(), psi_band_32_k.get_ngk_pointer(), false);
339-
// EXPECT_EQ(psi_band_32_k.get_nk(), ink);
340-
// EXPECT_EQ(psi_band_32_k.get_nbands(), inbands);
341-
// EXPECT_EQ(psi_band_32_k.get_nbasis(), inbasis);
342-
// EXPECT_EQ(psi_band_32_b.get_nk(), ink);
343-
// EXPECT_EQ(psi_band_32_b.get_nbands(), inbands);
344-
// EXPECT_EQ(psi_band_32_b.get_nbasis(), inbasis);
345-
// for (int ik = 0;ik < ink;++ik)
346-
// {
347-
// for (int ib = 0;ib < inbands;++ib)
348-
// {
349-
// psi_band_32->fix_kb(ik, ib);
350-
// psi_band_32_k.fix_kb(ik, ib);
351-
// psi_band_32_b.fix_kb(ik, ib);
352-
// EXPECT_EQ(psi_band_32->get_psi_bias(), (ib * ink + ik) * inbasis);
353-
// EXPECT_EQ(psi_band_32_k.get_psi_bias(), (ik * inbands + ib) * inbasis);
354-
// EXPECT_EQ(psi_band_32_b.get_psi_bias(), (ib * ink + ik) * inbasis);
355-
// }
356-
// }
357-
358314
delete psi_band_c64;
359315
delete psi_band_64;
360316
delete psi_band_c32;

0 commit comments

Comments
 (0)