Skip to content

Commit c2cb0df

Browse files
committed
update Constructor in psi
1 parent 0906e22 commit c2cb0df

File tree

8 files changed

+34
-12
lines changed

8 files changed

+34
-12
lines changed

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10801080
//! initialize the gradients of Etotal with respect to occupation numbers and wfc,
10811081
//! and set all elements to 0.
10821082
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
1083-
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
1083+
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis(), this->kv.ngk, true);
10841084
dE_dWfc.zero_out();
10851085

10861086
double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc);

source/module_esolver/esolver_of.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,11 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell)
220220

221221
// Refresh the arrays
222222
delete this->psi_;
223-
this->psi_ = new psi::Psi<double>(1, PARAM.inp.nspin, this->pw_rho->nrxx);
223+
this->psi_ = new psi::Psi<double>(1,
224+
PARAM.inp.nspin,
225+
this->pw_rho->nrxx,
226+
this->pw_rho->nrxx,
227+
true);
224228
for (int is = 0; is < PARAM.inp.nspin; ++is)
225229
{
226230
this->pphi_[is] = this->psi_->get_pointer(is);

source/module_esolver/esolver_of_tool.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ void ESolver_OF::init_elecstate(UnitCell& ucell)
7171
void ESolver_OF::allocate_array()
7272
{
7373
// Initialize the "wavefunction", which is sqrt(rho)
74-
this->psi_ = new psi::Psi<double>(1, PARAM.inp.nspin, this->pw_rho->nrxx);
74+
this->psi_ = new psi::Psi<double>(1,
75+
PARAM.inp.nspin,
76+
this->pw_rho->nrxx,
77+
this->pw_rho->nrxx,
78+
true);
7579
ModuleBase::Memory::record("OFDFT::Psi", sizeof(double) * PARAM.inp.nspin * this->pw_rho->nrxx);
7680
this->pphi_ = new double*[PARAM.inp.nspin];
7781
for (int is = 0; is < PARAM.inp.nspin; ++is)

source/module_io/get_pchg_lcao.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,12 @@ void IState_Charge::idmatrix(const int& ib,
541541
}
542542

543543
this->psi_k->fix_k(ik);
544-
// psi::Psi<std::complex<double>> wg_wfc(*this->psi_k, 1);
545-
psi::Psi<std::complex<double>> wg_wfc(1, this->psi_k->get_nbands(), this->psi_k->get_nbasis());
544+
545+
psi::Psi<std::complex<double>> wg_wfc(1,
546+
this->psi_k->get_nbands(),
547+
this->psi_k->get_nbasis(),
548+
this->psi_k->get_nbasis(),
549+
true);
546550

547551
for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir)
548552
{

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,11 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
181181
if (this->nbands == PARAM.inp.nbands) { move_gs(); }
182182
else // copy the part of ground state info according to paraC_
183183
{
184-
this->psi_ks = new psi::Psi<T>(this->kv.get_nks(), this->paraC_.get_col_size(), this->paraC_.get_row_size());
184+
this->psi_ks = new psi::Psi<T>(this->kv.get_nks(),
185+
this->paraC_.get_col_size(),
186+
this->paraC_.get_row_size(),
187+
this->kv.ngk,
188+
true);
185189
this->eig_ks.create(this->kv.get_nks(), this->nbands);
186190
const int start_band = this->nocc_max - *std::max_element(nocc.begin(), nocc.end());
187191
for (int ik = 0;ik < this->kv.get_nks();++ik)
@@ -289,8 +293,10 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
289293
// now ModuleIO::read_wfc_nao needs `Parallel_Orbitals` and can only read all the bands
290294
// it need improvement to read only the bands needed
291295
this->psi_ks = new psi::Psi<T>(this->kv.get_nks(),
292-
this->paraMat_.ncol_bands,
293-
this->paraMat_.get_row_size());
296+
this->paraMat_.ncol_bands,
297+
this->paraMat_.get_row_size(),
298+
this->kv.ngk,
299+
true);
294300
this->read_ks_wfc();
295301
if (nspin == 2)
296302
{

source/module_lr/hamilt_casida.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ namespace LR
1818
{//calculate A^{ai} for each bj
1919
int bj = j * nv + b; //global
2020
int kbj = ik * npairs + bj; //global
21-
psi::Psi<T> X_bj(1, 1, this->nk * px.get_local_size()); // k1-first, like in iterative solver
21+
psi::Psi<T> X_bj(1, 1, this->nk * px.get_local_size(), this->nk * px.get_local_size(), true); // k1-first, like in iterative solver
2222
X_bj.zero_out();
2323
// X_bj(0, 0, lj * px.get_row_size() + lb) = this->one();
2424
int lj = px.global2local_col(j);
2525
int lb = px.global2local_row(b);
2626
if (px.in_this_processor(b, j)) { X_bj(0, 0, ik * px.get_local_size() + lj * px.get_row_size() + lb) = this->one(); }
27-
psi::Psi<T> A_aibj(1, 1, this->nk * px.get_local_size()); // k1-first
27+
psi::Psi<T> A_aibj(1,
28+
1,
29+
this->nk * px.get_local_size(),
30+
this->nk * px.get_local_size(),
31+
true); // k1-first
2832
A_aibj.zero_out();
2933

3034
this->cal_dm_trans(0, X_bj.get_pointer());

source/module_psi/psi.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
9393
{
9494

9595
// Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
96-
// assert(nk_in == 1);
96+
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func
9797

9898
this->k_first = k_first_in;
9999
this->npol = PARAM.globalv.npol;

source/module_psi/psi.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class Psi
4040
Psi();
4141

4242
// Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later
43-
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true);
43+
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true);
4444

4545
Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector<int>& ngk_in, const bool k_first_in);
4646

0 commit comments

Comments
 (0)