Skip to content

Commit 1fb8851

Browse files
committed
fix bug about ngk
1 parent 1c2f523 commit 1fb8851

File tree

4 files changed

+24
-12
lines changed

4 files changed

+24
-12
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6565

6666
// a "psi" with the bands of needed range
6767
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);
68+
69+
// std::cout << "op->ik : " << op->ik << std::endl;
70+
// std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl;
71+
// std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl;
72+
6873

6974

7075
switch (op->get_act_type())

source/module_hsolver/hsolver_pw.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
378378
ngk_vector[i] = ngk_pointer[i];
379379
}
380380

381+
const int cur_nbasis = psi.get_current_nbas();
382+
381383
if (this->method == "cg")
382384
{
383385
// wrap the subspace_func into a lambda function
384-
auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
386+
auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
385387
// psi_in should be a 2D tensor:
386388
// psi_in.shape() = [nbands, nbasis]
387389
const auto ndim = psi_in.shape().ndim();
@@ -391,12 +393,14 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
391393
1,
392394
psi_in.shape().dim_size(0),
393395
psi_in.shape().dim_size(1),
394-
ngk_vector);
396+
ngk_vector,
397+
cur_nbasis);
395398
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
396399
1,
397400
psi_out.shape().dim_size(0),
398401
psi_out.shape().dim_size(1),
399-
ngk_vector);
402+
ngk_vector,
403+
cur_nbasis);
400404
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
401405
ct::DeviceType::CpuDevice,
402406
ct::TensorShape({psi_in.shape().dim_size(0)}));
@@ -415,7 +419,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
415419
using ct_Device = typename ct::PsiToContainer<Device>::type;
416420

417421
// wrap the hpsi_func and spsi_func into a lambda function
418-
auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
422+
auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
419423
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
420424
// psi_in should be a 2D tensor:
421425
// psi_in.shape() = [nbands, nbasis]
@@ -426,7 +430,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
426430
1,
427431
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
428432
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
429-
ngk_vector);
433+
ngk_vector,
434+
cur_nbasis);
430435
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
431436
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
432437
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
@@ -486,11 +491,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
486491
const int nband = psi.get_nbands();
487492
const int nbasis = psi.get_nbasis();
488493
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
489-
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
494+
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
490495
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
491496

492497
// Convert "pointer data stucture" to a psi::Psi object
493-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
498+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);
494499

495500
psi::Range bands_range(true, 0, 0, nvec - 1);
496501

@@ -507,11 +512,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
507512
else if (this->method == "dav_subspace")
508513
{
509514
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
510-
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
515+
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
511516
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
512517

513518
// Convert "pointer data stucture" to a psi::Psi object
514-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
519+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);
515520

516521
psi::Range bands_range(true, 0, 0, nvec - 1);
517522

@@ -558,11 +563,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
558563
// Davidson matrix-blockvector functions
559564
/// wrap hpsi into lambda function, Matrix \times blockvector
560565
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
561-
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
566+
auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
562567
ModuleBase::timer::tick("David", "hpsi_func");
563568

564569
// Convert pointer of psi_in to a psi::Psi object
565-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
570+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis);
566571

567572
psi::Range bands_range(true, 0, 0, nvec - 1);
568573

source/module_psi/psi.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
6969
const int nbd_in,
7070
const int nbs_in,
7171
const std::vector<int>& ngk_vector_in,
72+
const int current_nbasis_in,
7273
const bool k_first_in)
7374
{
7475
this->k_first = k_first_in;
@@ -79,7 +80,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
7980
this->nk = nk_in;
8081
this->nbands = nbd_in;
8182
this->nbasis = nbs_in;
82-
this->current_nbasis = nbs_in;
83+
this->current_nbasis = current_nbasis_in;
8384
this->psi_current = this->psi = psi_pointer;
8485
this->allocate_inside = false;
8586
// Currently only GPU's implementation is supported for device recording!

source/module_psi/psi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class Psi
6464
const int nbd_in,
6565
const int nbs_in,
6666
const std::vector<int>& ngk_vector_in,
67+
const int current_nbasis_in,
6768
const bool k_first_in = true);
6869

6970
// Constructor 8-2: a pointer version of constructor 3

0 commit comments

Comments
 (0)