Skip to content

Commit 2def09e

Browse files
committed
remove Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
1 parent 49987f8 commit 2def09e

File tree

5 files changed

+40
-28
lines changed

5 files changed

+40
-28
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6161
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
6262
syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
6363
delete this->hpsi;
64-
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
64+
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
65+
1,
66+
nbands / psi_input->npol,
67+
psi_input->get_nbasis(),
68+
psi_input->get_nbasis(),
69+
true);
6570
}
6671

6772
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
@@ -177,7 +182,13 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
177182
else
178183
{
179184
this->in_place = false;
180-
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range);
185+
// this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range);
186+
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
187+
1,
188+
nbands_range,
189+
std::get<0>(info)->get_nbasis(),
190+
std::get<0>(info)->get_nbasis(),
191+
true);
181192
}
182193

183194
hpsi_pointer = this->hpsi->get_pointer();

source/module_lr/utils/lr_util.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ namespace LR_Util
9999
template<typename T>
100100
psi::Psi<T> get_psi_spin(const psi::Psi<T>& psi_in, const int& is, const int& nk)
101101
{
102-
return psi::Psi<T>(&psi_in(is * nk, 0, 0), psi_in, nk, psi_in.get_nbands());
102+
return psi::Psi<T>(&psi_in(is * nk, 0, 0),
103+
nk,
104+
psi_in.get_nbands(),
105+
psi_in.get_nbasis(),
106+
true);
103107
}
104108

105109
/// psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy

source/module_psi/psi.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
9393
{
9494

9595
// Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
96-
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func
96+
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
9797

9898
this->k_first = k_first_in;
9999
this->npol = PARAM.globalv.npol;
@@ -193,24 +193,24 @@ Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, const int nband_in)
193193
}
194194
}
195195

196-
template <typename T, typename Device>
197-
Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
198-
{
199-
this->k_first = psi_in.get_k_first();
200-
assert(nk_in <= psi_in.get_nk());
201-
if (nband_in == 0)
202-
{
203-
nband_in = psi_in.get_nbands();
204-
}
205-
this->ngk = psi_in.ngk;
206-
this->npol = psi_in.npol;
207-
this->nk = nk_in;
208-
this->nbands = nband_in;
209-
this->nbasis = psi_in.nbasis;
210-
this->psi_current = psi_pointer;
211-
this->allocate_inside = false;
212-
this->psi = psi_pointer;
213-
}
196+
// template <typename T, typename Device>
197+
// Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
198+
// {
199+
// this->k_first = psi_in.get_k_first();
200+
// assert(nk_in <= psi_in.get_nk());
201+
// if (nband_in == 0)
202+
// {
203+
// nband_in = psi_in.get_nbands();
204+
// }
205+
// this->ngk = psi_in.ngk;
206+
// this->npol = psi_in.npol;
207+
// this->nk = nk_in;
208+
// this->nbands = nband_in;
209+
// this->nbasis = psi_in.nbasis;
210+
// this->psi_current = psi_pointer;
211+
// this->allocate_inside = false;
212+
// this->psi = psi_pointer;
213+
// }
214214

215215
template <typename T, typename Device>
216216
Psi<T, Device>::Psi(const Psi& psi_in)

source/module_psi/psi.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class Psi
4747
// Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in
4848
Psi(const Psi& psi_in, const int nk_in, const int nband_in);
4949

50-
// Constructor 5: a wrapper of a data pointer, used for Operator::hPsi()
51-
// in this case, fix_k can not be used
52-
Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0);
50+
// // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi()
51+
// // in this case, fix_k can not be used
52+
// Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in);
5353

5454
// Constructor 6: initialize a new psi from the given psi_in
5555
Psi(const Psi& psi_in);

source/module_psi/test/psi_test.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ class TestPsi : public ::testing::Test
1414
const psi::Psi<double>* psi_object32 = new psi::Psi<double>(ink, inbands, inbasis, &ngk[0]);
1515
const psi::Psi<std::complex<float>>* psi_object33 = new psi::Psi<std::complex<float>>(ink, inbands, inbasis, &ngk[0]);
1616
const psi::Psi<float>* psi_object34 = new psi::Psi<float>(ink, inbands, inbasis, &ngk[0]);
17-
18-
// psi::Psi<std::complex<double>>* psi_object4 = new psi::Psi<std::complex<double>>(*psi_object31, ink, 0);
19-
psi::Psi<std::complex<double>>* psi_object5 = new psi::Psi<std::complex<double>>(psi_object31->get_pointer(), *psi_object31, ink, 0);
2017
};
2118

2219
TEST_F(TestPsi, get_val)

0 commit comments

Comments
 (0)