Skip to content

Commit f4f958c

Browse files
committed
fix unit test
1 parent 7e44e9f commit f4f958c

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

source/module_io/test/write_wfc_nao_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class WriteWfcLcaoTest : public testing::Test
167167
TEST_F(WriteWfcLcaoTest, WriteWfcLcao)
168168
{
169169
// create a psi object
170-
psi::Psi<double> my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, true);
170+
psi::Psi<double> my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, nbasis_local, true);
171171
PARAM.sys.global_out_dir = "./";
172172
ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1);
173173

source/module_lr/utils/lr_util.hpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,19 @@ namespace LR_Util
104104

105105
/// psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy
106106
template<typename T, typename Device>
107-
psi::Psi<T, Device> k1_to_bfirst_wrapper(const psi::Psi<T, Device>& psi_kfirst, int nk_in, int nbasis_in)
107+
psi::Psi<T, Device> c(const psi::Psi<T, Device>& psi_kfirst, int nk_in, int nbasis_in)
108108
{
109109
assert(psi_kfirst.get_nk() == 1);
110110
assert(nk_in * nbasis_in == psi_kfirst.get_nbasis());
111111

112112
int ib_now = psi_kfirst.get_current_b();
113113
psi_kfirst.fix_b(0); // for get_pointer() to get the head pointer
114-
psi::Psi<T, Device> psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, false);
114+
psi::Psi<T, Device> psi_bfirst(psi_kfirst.get_pointer(),
115+
nk_in,
116+
psi_kfirst.get_nbands(),
117+
nbasis_in,
118+
nbasis_in,
119+
false);
115120
psi_kfirst.fix_b(ib_now);
116121
return psi_bfirst;
117122
}
@@ -124,7 +129,12 @@ namespace LR_Util
124129
int ik_now = psi_bfirst.get_current_k();
125130

126131
psi_bfirst.fix_kb(0, 0); // for get_pointer() to get the head pointer
127-
psi::Psi<T, Device> psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), true);
132+
psi::Psi<T, Device> psi_kfirst(psi_bfirst.get_pointer(),
133+
1,
134+
psi_bfirst.get_nbands(),
135+
psi_bfirst.get_nk() * psi_bfirst.get_nbasis(),
136+
psi_bfirst.get_nk() * psi_bfirst.get_nbasis(),
137+
true);
128138
psi_bfirst.fix_kb(ik_now, ib_now);
129139
return psi_kfirst;
130140
}

0 commit comments

Comments
 (0)