Skip to content

Commit fcc167f

Browse files
committed
fix bug
1 parent 30b1aa4 commit fcc167f

File tree

5 files changed

+26
-8
lines changed

5 files changed

+26
-8
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
8989
psi_input->npol,
9090
tmpsi_in,
9191
this->hpsi->get_pointer(),
92-
// psi_input->get_ngk(op->ik),
93-
psi_input->get_current_nbas(),
92+
psi_input->get_ngk(op->ik),
93+
// psi_input->get_current_nbas(),
9494
is_first_node);
9595
break;
9696
}

source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,20 @@ void Velocity::act
4848
) const
4949
{
5050
ModuleBase::timer::tick("Operator", "Velocity");
51-
// const int npw = psi_in->get_ngk(this->ik);
52-
const int npw = psi_in->get_current_nbas();
51+
52+
// if (psi_in->get_ngk(this->ik) != psi_in->get_current_nbas())
53+
// {
54+
// std::cout << "op->ik : " << this->ik << std::endl;
55+
// std::cout << "get_ngk(op->ik) : " << psi_in->get_ngk(this->ik) << std::endl;
56+
// std::cout << "get_current_nbas() : " << psi_in->get_current_nbas() << std::endl;
57+
58+
// std::cout << "ik : " << this->ik << std::endl;
59+
// }
60+
61+
62+
const int npw = psi_in->get_ngk(this->ik);
63+
// const int npw = psi_in->get_current_nbas();
64+
5365
const int max_npw = psi_in->get_nbasis() / psi_in->npol;
5466
const int npol = psi_in->npol;
5567
const std::complex<double>* tmpsi_in = psi0;

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
199199

200200
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
201201
{
202-
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, dmin, true);
202+
psi::Psi<T, Device> psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true);
203203
T* ppsi = psi_temp.get_pointer();
204204
// hpsi and spsi share the temp space
205205
T* temp = nullptr;
@@ -246,7 +246,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
246246
}
247247
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
248248
{
249-
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, dmin, true);
249+
psi::Psi<T, Device> psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true);
250250
T* ppsi = psi_temp.get_pointer();
251251
syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size());
252252
// hpsi and spsi share the temp space

source/module_psi/psi.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ template <typename T, typename Device>
118118
Psi<T, Device>::Psi(const int nk_in,
119119
const int nbd_in,
120120
const int nbs_in,
121+
const int* ngk_in,
121122
const int current_nbasis_in,
122123
const bool k_first_in)
123124
{
@@ -129,7 +130,7 @@ Psi<T, Device>::Psi(const int nk_in,
129130
this->npol = PARAM.globalv.npol;
130131
this->allocate_inside = true;
131132

132-
this->ngk = nullptr;
133+
this->ngk = ngk_in;
133134
assert(nk_in > 0 && nbd_in > 0 && nbs_in > 0);
134135
resize_memory_op()(this->ctx, this->psi, nk_in * static_cast<std::size_t>(nbd_in) * nbs_in, "no_record");
135136

source/module_psi/psi.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ class Psi
7373

7474

7575
// Constructor 8-3: 2D Psi version 3
76-
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in);
76+
Psi(const int nk_in,
77+
const int nbd_in,
78+
const int nbs_in,
79+
const int* ngk_in,
80+
const int current_nbasis_in,
81+
const bool k_first_in);
7782

7883

7984
// Destructor for deleting the psi array manually

0 commit comments

Comments
 (0)