Skip to content

Commit b3c528b

Browse files
committed
change test code for psi Constructor
1 parent 665a11a commit b3c528b

File tree

12 files changed

+86
-38
lines changed

12 files changed

+86
-38
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6464
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
6565

6666
// a "psi" with the bands of needed range
67-
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), nullptr, true);
67+
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);
6868

6969

7070
switch (op->get_act_type())

source/module_hsolver/test/diago_cg_float_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class DiagoCGPrepare
164164
auto psi_wrapper = psi::Psi<std::complex<float>>(
165165
psi_in.data<std::complex<float>>(), 1,
166166
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
167-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
167+
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
168168
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
169169
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
170170
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<float>>());

source/module_hsolver/test/diago_cg_real_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class DiagoCGPrepare
167167
auto psi_wrapper = psi::Psi<double>(
168168
psi_in.data<double>(), 1,
169169
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
170-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
170+
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
171171
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
172172
using hpsi_info = typename hamilt::Operator<double>::hpsi_info;
173173
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<double>());

source/module_hsolver/test/diago_cg_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class DiagoCGPrepare
158158
auto psi_wrapper = psi::Psi<std::complex<double>>(
159159
psi_in.data<std::complex<double>>(), 1,
160160
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
161-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
161+
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
162162
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
163163
using hpsi_info = typename hamilt::Operator<std::complex<double>>::hpsi_info;
164164
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<double>>());

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class DiagoDavPrepare
113113
auto hpsi_func = [phm](std::complex<float>* psi_in,std::complex<float>* hpsi_out,
114114
const int ld_psi, const int nvec)
115115
{
116-
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, nullptr);
116+
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, true);
117117
psi::Range bands_range(true, 0, 0, nvec-1);
118118
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
119119
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

source/module_hsolver/test/diago_david_real_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class DiagoDavPrepare
112112
auto hpsi_func = [phm](double* psi_in,double* hpsi_out,
113113
const int ld_psi, const int nvec)
114114
{
115-
auto psi_iter_wrapper = psi::Psi<double>(psi_in, 1, nvec, ld_psi, nullptr);
115+
auto psi_iter_wrapper = psi::Psi<double>(psi_in, 1, nvec, ld_psi, true);
116116
psi::Range bands_range(true, 0, 0, nvec-1);
117117
using hpsi_info = typename hamilt::Operator<double>::hpsi_info;
118118
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

source/module_hsolver/test/diago_david_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class DiagoDavPrepare
115115
auto hpsi_func = [phm](std::complex<double>* psi_in,std::complex<double>* hpsi_out,
116116
const int ld_psi, const int nvec)
117117
{
118-
auto psi_iter_wrapper = psi::Psi<std::complex<double>>(psi_in, 1, nvec, ld_psi, nullptr);
118+
auto psi_iter_wrapper = psi::Psi<std::complex<double>>(psi_in, 1, nvec, ld_psi, true);
119119
psi::Range bands_range(true, 0, 0, nvec-1);
120120
using hpsi_info = typename hamilt::Operator<std::complex<double>>::hpsi_info;
121121
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

source/module_io/test/write_wfc_nao_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class WriteWfcLcaoTest : public testing::Test
167167
TEST_F(WriteWfcLcaoTest, WriteWfcLcao)
168168
{
169169
// create a psi object
170-
psi::Psi<double> my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local);
170+
psi::Psi<double> my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, true);
171171
PARAM.sys.global_out_dir = "./";
172172
ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1);
173173

@@ -196,7 +196,7 @@ TEST_F(WriteWfcLcaoTest, WriteWfcLcao)
196196

197197
TEST_F(WriteWfcLcaoTest, WriteWfcLcaoComplex)
198198
{
199-
psi::Psi<std::complex<double>> my_psi(psi_local_complex.data(), nk, nbands_local, nbasis_local);
199+
psi::Psi<std::complex<double>> my_psi(psi_local_complex.data(), nk, nbands_local, nbasis_local, true);
200200
PARAM.sys.global_out_dir = "./";
201201
ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1);
202202

source/module_lr/utils/lr_util.hpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,16 @@ namespace LR_Util
108108
{
109109
assert(psi_kfirst.get_nk() == 1);
110110
assert(nk_in * nbasis_in == psi_kfirst.get_nbasis());
111+
112+
std::vector<int> ngk_vector(nk_in, 0);
113+
for (size_t i = 0; i < ngk_vector.size(); i++)
114+
{
115+
ngk_vector[i] = psi_kfirst.get_ngk_pointer()[i];
116+
}
117+
111118
int ib_now = psi_kfirst.get_current_b();
112119
psi_kfirst.fix_b(0); // for get_pointer() to get the head pointer
113-
psi::Psi<T, Device> psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, psi_kfirst.get_ngk_pointer(), false);
120+
psi::Psi<T, Device> psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, ngk_vector, false);
114121
psi_kfirst.fix_b(ib_now);
115122
return psi_bfirst;
116123
}
@@ -121,8 +128,15 @@ namespace LR_Util
121128
{
122129
int ib_now = psi_bfirst.get_current_b();
123130
int ik_now = psi_bfirst.get_current_k();
131+
132+
std::vector<int> ngk_vector(psi_bfirst.get_nk(), 0);
133+
for (size_t i = 0; i < ngk_vector.size(); i++)
134+
{
135+
ngk_vector[i] = psi_bfirst.get_ngk_pointer()[i];
136+
}
137+
124138
psi_bfirst.fix_kb(0, 0); // for get_pointer() to get the head pointer
125-
psi::Psi<T, Device> psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), psi_bfirst.get_ngk_pointer(), true);
139+
psi::Psi<T, Device> psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), ngk_vector, true);
126140
psi_bfirst.fix_kb(ik_now, ib_now);
127141
return psi_kfirst;
128142
}

source/module_psi/psi.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,41 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
6969
sizeof(T) * nk_in * nbd_in * nbs_in);
7070
}
7171

72+
// Constructor 8-1:
73+
// template <typename T, typename Device>
74+
// Psi<T, Device>::Psi(T* psi_pointer,
75+
// const int nk_in,
76+
// const int nbd_in,
77+
// const int nbs_in,
78+
// const int* ngk_in,
79+
// const bool k_first_in)
80+
// {
81+
// this->k_first = k_first_in;
82+
// this->ngk = ngk_in;
83+
// this->current_b = 0;
84+
// this->current_k = 0;
85+
// this->npol = PARAM.globalv.npol;
86+
// this->device = base_device::get_device_type<Device>(this->ctx);
87+
// this->nk = nk_in;
88+
// this->nbands = nbd_in;
89+
// this->nbasis = nbs_in;
90+
// this->current_nbasis = nbs_in;
91+
// this->psi_current = this->psi = psi_pointer;
92+
// this->allocate_inside = false;
93+
// // Currently only GPU's implementation is supported for device recording!
94+
// base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
95+
// }
96+
97+
// Constructor 8-3:
7298
template <typename T, typename Device>
7399
Psi<T, Device>::Psi(T* psi_pointer,
74100
const int nk_in,
75101
const int nbd_in,
76102
const int nbs_in,
77-
const int* ngk_in,
78103
const bool k_first_in)
79104
{
80105
this->k_first = k_first_in;
81-
this->ngk = ngk_in;
106+
this->ngk = nullptr;
82107
this->current_b = 0;
83108
this->current_k = 0;
84109
this->npol = PARAM.globalv.npol;
@@ -93,7 +118,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
93118
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
94119
}
95120

96-
121+
// Constructor 8-2:
97122
template <typename T, typename Device>
98123
Psi<T, Device>::Psi(T* psi_pointer,
99124
const int nk_in,

0 commit comments

Comments
 (0)