Skip to content

Commit 8935299

Browse files
committed
debug unit test
1 parent 5a86f45 commit 8935299

File tree

5 files changed

+42
-19
lines changed

5 files changed

+42
-19
lines changed

source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ TEST_F(TestStoTool, parallel_distribution)
6868

6969
TEST_F(TestStoTool, convert_psi)
7070
{
71-
psi::Psi<std::complex<double>> psi_in(1, 1, 10);
72-
psi::Psi<std::complex<float>> psi_out(1, 1, 10);
71+
psi::Psi<std::complex<double>> psi_in(1, 1, 10, 10, true);
72+
psi::Psi<std::complex<float>> psi_out(1, 1, 10, 10, true);
7373
for (int i = 0; i < 10; ++i)
7474
{
7575
psi_in.get_pointer()[i] = std::complex<double>(i, i);
@@ -83,8 +83,8 @@ TEST_F(TestStoTool, convert_psi)
8383

8484
TEST_F(TestStoTool, gatherchi)
8585
{
86-
psi::Psi<std::complex<float>> chi(1, 1, 10);
87-
psi::Psi<std::complex<float>> chi_all(1, 1, 10);
86+
psi::Psi<std::complex<float>> chi(1, 1, 10, 10, true);
87+
psi::Psi<std::complex<float>> chi_all(1, 1, 10, 10, true);
8888
int npwx = 10;
8989
int nrecv_sto[4] = {1, 2, 3, 4};
9090
int displs_sto[4] = {0, 1, 3, 6};

source/module_lr/AX/test/AX_test.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ TEST_F(AXTest, DoubleSerial)
7070
int size_v = s.naos * s.naos;
7171
for (int istate = 0;istate < nstate;++istate)
7272
{
73-
psi::Psi<double> c(s.nks, s.nocc + s.nvirt, s.naos);
73+
std::vector<int> temp(s.nks, s.naos);
74+
psi::Psi<double> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true);
7475
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
7576
set_rand(&c(0, 0, 0), size_c);
7677
for (auto& v : V) { set_rand(v.data<double>(), size_v); }
@@ -91,7 +92,8 @@ TEST_F(AXTest, ComplexSerial)
9192
int size_v = s.naos * s.naos;
9293
for (int istate = 0;istate < nstate;++istate)
9394
{
94-
psi::Psi<std::complex<double>> c(s.nks, s.nocc + s.nvirt, s.naos);
95+
std::vector<int> temp(s.nks, s.naos);
96+
psi::Psi<std::complex<double>> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true);
9597
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
9698
set_rand(&c(0, 0, 0), size_c);
9799
for (auto& v : V) { set_rand(v.data<std::complex<double>>(), size_v); }
@@ -113,7 +115,9 @@ TEST_F(AXTest, DoubleParallel)
113115
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() }));
114116
Parallel_2D pc;
115117
LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt);
116-
psi::Psi<double> c(s.nks, pc.get_col_size(), pc.get_row_size());
118+
119+
std::vector<int> ngk_temp(s.nks, pc.get_row_size());
120+
psi::Psi<double> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp.data(), true);
117121
Parallel_2D px;
118122
LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt);
119123

@@ -139,7 +143,9 @@ TEST_F(AXTest, DoubleParallel)
139143
}
140144
// compare to global AX
141145
std::vector<container::Tensor> V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
142-
psi::Psi<double> c_full(s.nks, s.nocc + s.nvirt, s.naos);
146+
147+
std::vector<int> ngk_temp_1(s.nks, s.naos);
148+
psi::Psi<double> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1.data(), true);
143149
for (int isk = 0;isk < s.nks;++isk)
144150
{
145151
LR_Util::gather_2d_to_full(pV, V.at(isk).data<double>(), V_full.at(isk).data<double>(), false, s.naos, s.naos);
@@ -165,7 +171,9 @@ TEST_F(AXTest, ComplexParallel)
165171
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() }));
166172
Parallel_2D pc;
167173
LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt);
168-
psi::Psi<std::complex<double>> c(s.nks, pc.get_col_size(), pc.get_row_size());
174+
175+
std::vector<int> ngk_temp_1(s.nks, pc.get_row_size());
176+
psi::Psi<std::complex<double>> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1.data(), true);
169177
Parallel_2D px;
170178
LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt);
171179

@@ -187,7 +195,10 @@ TEST_F(AXTest, ComplexParallel)
187195
}
188196
// compare to global AX
189197
std::vector<container::Tensor> V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
190-
psi::Psi<std::complex<double>> c_full(s.nks, s.nocc + s.nvirt, s.naos);
198+
199+
200+
std::vector<int> ngk_temp_2(s.nks, s.naos);
201+
psi::Psi<std::complex<double>> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true);
191202
for (int isk = 0;isk < s.nks;++isk)
192203
{
193204
LR_Util::gather_2d_to_full(pV, V.at(isk).data<std::complex<double>>(), V_full.at(isk).data<std::complex<double>>(), false, s.naos, s.naos);

source/module_lr/dm_trans/test/dm_trans_test.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ TEST_F(DMTransTest, DoubleSerial)
6666
for (int istate = 0;istate < nstate;++istate)
6767
{
6868
int size_c = s.nks * (s.nocc + s.nvirt) * s.naos;
69-
psi::Psi<double> c(s.nks, s.nocc + s.nvirt, s.naos);
69+
70+
std::vector<int> temp(s.nks, s.naos);
71+
psi::Psi<double> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true);
7072
set_rand(c.get_pointer(), size_c);
7173
X.fix_b(istate);
7274
const std::vector<container::Tensor>& dm_for = LR::cal_dm_trans_forloop_serial(X.get_pointer(), c, s.nocc, s.nvirt);
@@ -85,7 +87,9 @@ TEST_F(DMTransTest, ComplexSerial)
8587
for (int istate = 0;istate < nstate;++istate)
8688
{
8789
int size_c = s.nks * (s.nocc + s.nvirt) * s.naos;
88-
psi::Psi<std::complex<double>> c(s.nks, s.nocc + s.nvirt, s.naos);
90+
91+
std::vector<int> temp(s.nks, s.naos);
92+
psi::Psi<std::complex<double>> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true);
8993
set_rand(c.get_pointer(), size_c);
9094
X.fix_b(istate);
9195
const std::vector<container::Tensor>& dm_for = LR::cal_dm_trans_forloop_serial(X.get_pointer(), c, s.nocc, s.nvirt);
@@ -105,10 +109,14 @@ TEST_F(DMTransTest, DoubleParallel)
105109
// X: nvirt*nocc in para2d, nocc*nvirt in psi (row-para and constructed: nvirt)
106110
Parallel_2D px;
107111
LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc);
108-
psi::Psi<double> X(s.nks, nstate, px.get_local_size(), nullptr, false);
112+
113+
std::vector<int> temp_1(s.nks, px.get_local_size());
114+
psi::Psi<double> X(s.nks, nstate, px.get_local_size(), temp_1.data(), false);
109115
Parallel_2D pc;
110116
LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px.blacs_ctxt);
111-
psi::Psi<double> c(s.nks, pc.get_col_size(), pc.get_row_size());
117+
118+
std::vector<int> temp_2(s.nks, pc.get_row_size());
119+
psi::Psi<double> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2.data(), true);
112120
Parallel_2D pmat;
113121
LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px.blacs_ctxt);
114122

@@ -147,7 +155,8 @@ TEST_F(DMTransTest, DoubleParallel)
147155
LR_Util::gather_2d_to_full(pmat, dm_pblas_loc[isk].data<double>(), dm_gather[isk].data<double>(), false, s.naos, s.naos);
148156

149157
// compare to global matrix
150-
psi::Psi<double> c_full(s.nks, s.nocc + s.nvirt, s.naos);
158+
std::vector<int> temp(s.nks, s.naos);
159+
psi::Psi<double> c_full(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true);
151160
for (int isk = 0;isk < s.nks;++isk)
152161
{
153162
c.fix_k(isk);
@@ -173,7 +182,9 @@ TEST_F(DMTransTest, ComplexParallel)
173182
psi::Psi<std::complex<double>> X(s.nks, nstate, px.get_local_size(), nullptr, false);
174183
Parallel_2D pc;
175184
LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px.blacs_ctxt);
176-
psi::Psi<std::complex<double>> c(s.nks, pc.get_col_size(), pc.get_row_size());
185+
186+
std::vector<int> temp(s.nks, pc.get_row_size());
187+
psi::Psi<std::complex<double>> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp.data(), true);
177188
Parallel_2D pmat;
178189
LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px.blacs_ctxt);
179190

@@ -206,7 +217,8 @@ TEST_F(DMTransTest, ComplexParallel)
206217
LR_Util::gather_2d_to_full(pmat, dm_pblas_loc[isk].data<std::complex<double>>(), dm_gather[isk].data<std::complex<double>>(), false, s.naos, s.naos);
207218

208219
// compare to global matrix
209-
psi::Psi<std::complex<double>> c_full(s.nks, s.nocc + s.nvirt, s.naos);
220+
std::vector<int> ngk_temp_2(s.nks, s.naos);
221+
psi::Psi<std::complex<double>> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true);
210222
for (int isk = 0;isk < s.nks;++isk)
211223
{
212224
c.fix_k(isk);

source/module_lr/utils/test/lr_util_algorithms_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ TEST(LR_Util, PsiWrapper)
99
int nbands = 5;
1010
int nbasis = 6;
1111

12-
psi::Psi<float> k1(1, nbands, nk * nbasis);
12+
psi::Psi<float> k1(1, nbands, nk * nbasis, nk * nbasis, true);
1313
for (int i = 0;i < nbands * nk * nbasis;++i)k1.get_pointer()[i] = i;
1414

1515
k1.fix_b(2);

source/module_psi/test/psi_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ TEST_F(TestPsi, get_pointer_op_zero_complex_double)
9999
// cover all lines in fix_k func
100100
psi_object31->fix_k(2);
101101
EXPECT_EQ(psi_object31->get_psi_bias(), 0);
102-
psi::Psi<std::complex<double>>* psi_temp = new psi::Psi<std::complex<double>>(ink, inbands, inbasis);
102+
psi::Psi<std::complex<double>>* psi_temp = new psi::Psi<std::complex<double>>(ink, inbands, inbasis, inbasis, true);
103103
psi_temp->fix_k(0);
104104
EXPECT_EQ(psi_object31->get_current_nbas(), inbasis);
105105
delete psi_temp;

0 commit comments

Comments
 (0)