Skip to content

Commit 7e44e9f

Browse files
committed
check 1
1 parent 6ff1b3a commit 7e44e9f

File tree

5 files changed

+31
-68
lines changed

5 files changed

+31
-68
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,12 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6666

6767
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
6868
// a "psi" with the bands of needed range
69-
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);
70-
71-
72-
// if (psi_input->get_ngk(op->ik) != psi_input->get_current_nbas())
73-
// {
74-
// std::cout << "op->ik : " << op->ik << std::endl;
75-
// std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl;
76-
// std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl;
77-
78-
// std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl;
79-
// }
69+
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in),
70+
1,
71+
nbands,
72+
psi_input->get_nbasis(),
73+
psi_input->get_nbasis(),
74+
true);
8075

8176
switch (op->get_act_type())
8277
{
@@ -89,7 +84,6 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
8984
psi_input->npol,
9085
tmpsi_in,
9186
this->hpsi->get_pointer(),
92-
// psi_input->get_ngk(op->ik),
9387
psi_input->get_current_nbas(),
9488
is_first_node);
9589
break;

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
199199

200200
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
201201
{
202-
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true);
203-
// psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ngk(0));
202+
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, dmin, true);
203+
204204
T* ppsi = psi_temp.get_pointer();
205205
// hpsi and spsi share the temp space
206206
T* temp = nullptr;
@@ -247,7 +247,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
247247
}
248248
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
249249
{
250-
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true);
250+
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, dmin, true);
251251
// psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ngk(0));
252252

253253
T* ppsi = psi_temp.get_pointer();

source/module_hsolver/hsolver_pw.cpp

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310
#endif
311311

312312
/// solve eigenvector and eigenvalue for H(k)
313-
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks);
313+
this->hamiltSolvePsiK(pHamilt,
314+
psi,
315+
precondition,
316+
eigenvalues.data() + ik * psi.get_nbands(),
317+
this->wfc_basis->nks);
314318

315319
if (skip_charge)
316320
{
@@ -370,20 +374,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
370374
const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool};
371375
#endif
372376

373-
auto ngk_pointer = psi.get_ngk_pointer();
374-
375-
std::vector<int> ngk_vector(nk_nums, 0);
376-
for (int i = 0; i < nk_nums; i++)
377-
{
378-
ngk_vector[i] = ngk_pointer[i];
379-
}
380-
381377
const int cur_nbasis = psi.get_ngk(psi.get_current_k());
382378

383379
if (this->method == "cg")
384380
{
385381
// wrap the subspace_func into a lambda function
386-
auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
382+
auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
387383
// psi_in should be a 2D tensor:
388384
// psi_in.shape() = [nbands, nbasis]
389385
const auto ndim = psi_in.shape().ndim();
@@ -393,13 +389,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
393389
1,
394390
psi_in.shape().dim_size(0),
395391
psi_in.shape().dim_size(1),
396-
ngk_vector,
397392
cur_nbasis);
398393
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
399394
1,
400395
psi_out.shape().dim_size(0),
401396
psi_out.shape().dim_size(1),
402-
ngk_vector,
403397
cur_nbasis);
404398
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
405399
ct::DeviceType::CpuDevice,
@@ -419,7 +413,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
419413
using ct_Device = typename ct::PsiToContainer<Device>::type;
420414

421415
// wrap the hpsi_func and spsi_func into a lambda function
422-
auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
416+
auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
423417
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
424418
// psi_in should be a 2D tensor:
425419
// psi_in.shape() = [nbands, nbasis]
@@ -430,7 +424,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
430424
1,
431425
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
432426
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
433-
ngk_vector,
434427
cur_nbasis);
435428
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
436429
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
@@ -491,11 +484,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
491484
const int nband = psi.get_nbands();
492485
const int nbasis = psi.get_nbasis();
493486
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
494-
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
487+
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
495488
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
496489

497490
// Convert "pointer data stucture" to a psi::Psi object
498-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);
491+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
499492

500493
psi::Range bands_range(true, 0, 0, nvec - 1);
501494

@@ -512,11 +505,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
512505
else if (this->method == "dav_subspace")
513506
{
514507
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
515-
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
508+
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
516509
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
517510

518511
// Convert "pointer data stucture" to a psi::Psi object
519-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);
512+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
520513

521514
psi::Range bands_range(true, 0, 0, nvec - 1);
522515

@@ -557,17 +550,17 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
557550

558551
// dimensions of matrix to be solved
559552
const int dim = psi.get_cur_effective_basis(); /// dimension of matrix
560-
const int nband = psi.get_nbands(); /// number of eigenpairs sought
561-
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
553+
const int nband = psi.get_nbands(); /// number of eigenpairs sought
554+
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
562555

563556
// Davidson matrix-blockvector functions
564557
/// wrap hpsi into lambda function, Matrix \times blockvector
565558
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
566-
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
559+
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
567560
ModuleBase::timer::tick("David", "hpsi_func");
568561

569562
// Convert pointer of psi_in to a psi::Psi object
570-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);
563+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
571564

572565
psi::Range bands_range(true, 0, 0, nvec - 1);
573566

source/module_psi/psi.cpp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,18 @@ Psi<T, Device>::Psi(T* psi_pointer,
6868
const int nk_in,
6969
const int nbd_in,
7070
const int nbs_in,
71-
const std::vector<int>& ngk_vector_in,
7271
const int current_nbasis_in,
7372
const bool k_first_in)
7473
{
74+
75+
// Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
76+
assert(nk_in == 1);
77+
7578
this->k_first = k_first_in;
7679
this->npol = PARAM.globalv.npol;
7780
this->allocate_inside = false;
7881

79-
this->ngk = ngk_vector_in.data();
82+
this->ngk = nullptr;
8083

8184
this->psi = psi_pointer;
8285

@@ -94,31 +97,11 @@ Psi<T, Device>::Psi(T* psi_pointer,
9497
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
9598
}
9699

97-
// Constructor 8-2:
98-
template <typename T, typename Device>
99-
Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in)
100-
{
101-
this->k_first = k_first_in;
102-
this->ngk = nullptr;
103-
this->current_b = 0;
104-
this->current_k = 0;
105-
this->npol = PARAM.globalv.npol;
106-
this->nk = nk_in;
107-
this->nbands = nbd_in;
108-
this->nbasis = nbs_in;
109-
this->current_nbasis = nbs_in;
110-
this->psi_current = this->psi = psi_pointer;
111-
this->allocate_inside = false;
112-
// Currently only GPU's implementation is supported for device recording!
113-
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
114-
}
115-
116100
// Constructor 8-3: 2D Psi version 3
117101
template <typename T, typename Device>
118102
Psi<T, Device>::Psi(const int nk_in,
119103
const int nbd_in,
120104
const int nbs_in,
121-
const int* ngk_in,
122105
const int current_nbasis_in,
123106
const bool k_first_in)
124107
{
@@ -130,15 +113,15 @@ Psi<T, Device>::Psi(const int nk_in,
130113
this->npol = PARAM.globalv.npol;
131114
this->allocate_inside = true;
132115

133-
this->ngk = ngk_in;
116+
this->ngk = nullptr;
134117
assert(nk_in > 0 && nbd_in > 0 && nbs_in > 0);
135118
resize_memory_op()(this->ctx, this->psi, nk_in * static_cast<std::size_t>(nbd_in) * nbs_in, "no_record");
136119

137120
this->nk = nk_in;
138121
this->nbands = nbd_in;
139122
this->nbasis = nbs_in;
140123

141-
this->current_k = 0;
124+
this->current_k = 0;
142125
this->current_b = 0;
143126
this->current_nbasis = current_nbasis_in;
144127
this->psi_current = this->psi;

source/module_psi/psi.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,13 @@ class Psi
6363
const int nk_in,
6464
const int nbd_in,
6565
const int nbs_in,
66-
const std::vector<int>& ngk_vector_in,
6766
const int current_nbasis_in,
6867
const bool k_first_in = true);
6968

70-
// Constructor 8-2: a pointer version of constructor 3
71-
// only used in operator.cpp call_act func
72-
Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in);
73-
74-
7569
// Constructor 8-3: 2D Psi version 3
7670
Psi(const int nk_in,
7771
const int nbd_in,
7872
const int nbs_in,
79-
const int* ngk_in,
8073
const int current_nbasis_in,
8174
const bool k_first_in);
8275

0 commit comments

Comments
 (0)