Skip to content

Commit 5735afb

Browse files
committed
fix UT bugs
1 parent 05fe897 commit 5735afb

File tree

6 files changed

+76
-68
lines changed

6 files changed

+76
-68
lines changed

source/module_io/test/read_wfc_nao_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ TEST_F(ReadWfcNaoTest,ReadWfcNao)
6969
pelec.ekb.create(nks,nbands);
7070
pelec.wg.create(nks,nbands);
7171
// Act
72-
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid, &(pelec));
72+
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid, pelec.wg, pelec.ekb);
7373
// Assert
7474
EXPECT_NEAR(pelec.ekb(0,1),0.31482195194888534794941393,1e-5);
7575
EXPECT_NEAR(pelec.wg(0,1),0.0,1e-5);
@@ -106,7 +106,7 @@ TEST_F(ReadWfcNaoTest, ReadWfcNaoPart)
106106
pelec.ekb.create(nks, nbands);
107107
pelec.wg.create(nks, nbands);
108108
// Act
109-
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid, &(pelec), /*skip_band=*/1);
109+
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid, pelec.wg, pelec.ekb, /*skip_band=*/1);
110110
// Assert
111111
EXPECT_NEAR(pelec.ekb(0, 1), 7.4141254894954844445464914e-01, 1e-5);
112112
if (my_rank == 0)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
remove_definitions(-DUSE_LIBXC)
22
AddTest(
33
TARGET CVCX_test
4-
LIBS base ${math_libs} container device psi
4+
LIBS base parameter ${math_libs} container device psi
55
SOURCES CVCX_test.cpp ../../../utils/lr_util.cpp ../CVCX_parallel.cpp ../CVCX_serial.cpp
66
)

source/module_lr/Grad/CVCX/test/CVCX_test.cpp

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class AXTest : public testing::Test
2121
{
2222
public:
2323
std::vector<matsize> sizes{
24-
// {2, 3, 2, 1},
24+
// {2, 3, 2, 1}
2525
{2, 13, 7, 4},
2626
{2, 14, 8, 5}
2727
};
@@ -61,32 +61,32 @@ TEST_F(AXTest, DoubleSerial)
6161
{
6262
for (auto s : this->sizes)
6363
{
64-
psi::Psi<double, base_device::DEVICE_CPU> X(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
65-
psi::Psi<double, base_device::DEVICE_CPU> AX_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
66-
psi::Psi<double, base_device::DEVICE_CPU> AX_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
64+
psi::Psi<double, base_device::DEVICE_CPU> X(s.nks, nstate, s.nocc * s.nvirt, {}, false);
65+
psi::Psi<double, base_device::DEVICE_CPU> AX_for(s.nks, nstate, s.nocc * s.nvirt, {}, false);
66+
psi::Psi<double, base_device::DEVICE_CPU> AX_blas(s.nks, nstate, s.nocc * s.nvirt, {}, false);
6767
const int size_x = nstate * s.nks * s.nocc * s.nvirt;
6868
set_rand(X.get_pointer(), size_x);
6969

7070
const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos;
7171
const int size_v = s.naos * s.naos;
7272
for (int istate = 0;istate < nstate;++istate)
7373
{
74-
psi::Psi<double, base_device::DEVICE_CPU> c(s.nks, s.nocc + s.nvirt, s.naos);
74+
psi::Psi<double, base_device::DEVICE_CPU> c(s.nks, s.nocc + s.nvirt, s.naos, {}, true);
7575
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
7676
set_rand(c.get_pointer(), size_c);
7777
for (auto& v : V)set_rand(v.data<double>(), size_v);
7878
X.fix_b(istate);
7979
AX_for.fix_b(istate);
8080
AX_blas.fix_b(istate);
8181
// occ
82-
CVCX_occ_forloop_serial(V, c, X, s.naos, s.nocc, s.nvirt, AX_for);
83-
CVCX_occ_blas(V, c, X, s.naos, s.nocc, s.nvirt, AX_blas, false);
82+
LR::CVCX_occ_forloop_serial(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_for.get_pointer());
83+
LR::CVCX_occ_blas(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_blas.get_pointer(), false);
8484
AX_for.fix_k(0);
8585
AX_blas.fix_k(0);
8686
check_eq(AX_for.get_pointer(), AX_blas.get_pointer(), s.nks * s.nocc * s.nvirt);
8787
// virt
88-
CVCX_virt_forloop_serial(V, c, X, s.naos, s.nocc, s.nvirt, AX_for);
89-
CVCX_virt_blas(V, c, X, s.naos, s.nocc, s.nvirt, AX_blas, false);
88+
LR::CVCX_virt_forloop_serial(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_for.get_pointer());
89+
LR::CVCX_virt_blas(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_blas.get_pointer(), false);
9090
AX_for.fix_k(0);
9191
AX_blas.fix_k(0);
9292
check_eq(AX_for.get_pointer(), AX_blas.get_pointer(), s.nks * s.nocc * s.nvirt);
@@ -98,32 +98,32 @@ TEST_F(AXTest, ComplexSerial)
9898
{
9999
for (auto s : this->sizes)
100100
{
101-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> X(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
102-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
103-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
101+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> X(s.nks, nstate, s.nocc * s.nvirt, {}, false);
102+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_for(s.nks, nstate, s.nocc * s.nvirt, {}, false);
103+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_blas(s.nks, nstate, s.nocc * s.nvirt, {}, false);
104104
const int size_x = nstate * s.nks * s.nocc * s.nvirt;
105105
set_rand(X.get_pointer(), size_x);
106106

107107
int size_c = s.nks * (s.nocc + s.nvirt) * s.naos;
108108
int size_v = s.naos * s.naos;
109109
for (int istate = 0;istate < nstate;++istate)
110110
{
111-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> c(s.nks, s.nocc + s.nvirt, s.naos);
111+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> c(s.nks, s.nocc + s.nvirt, s.naos, {}, true);
112112
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
113113
set_rand(c.get_pointer(), size_c);
114114
for (auto& v : V)set_rand(v.data<std::complex<double>>(), size_v);
115115
X.fix_b(istate);
116116
AX_for.fix_b(istate);
117117
AX_blas.fix_b(istate);
118118
// occ
119-
CVCX_occ_forloop_serial(V, c, X, s.naos, s.nocc, s.nvirt, AX_for);
120-
CVCX_occ_blas(V, c, X, s.naos, s.nocc, s.nvirt, AX_blas, false);
119+
LR::CVCX_occ_forloop_serial(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_for.get_pointer());
120+
LR::CVCX_occ_blas(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_blas.get_pointer(), false);
121121
AX_for.fix_k(0);
122122
AX_blas.fix_k(0);
123123
check_eq(AX_for.get_pointer(), AX_blas.get_pointer(), s.nks * s.nocc * s.nvirt);
124124
// virt
125-
CVCX_virt_forloop_serial(V, c, X, s.naos, s.nocc, s.nvirt, AX_for);
126-
CVCX_virt_blas(V, c, X, s.naos, s.nocc, s.nvirt, AX_blas, false);
125+
LR::CVCX_virt_forloop_serial(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_for.get_pointer());
126+
LR::CVCX_virt_blas(V, c, X.get_pointer(), s.naos, s.nocc, s.nvirt, AX_blas.get_pointer(), false);
127127
AX_for.fix_k(0);
128128
AX_blas.fix_k(0);
129129
check_eq(AX_for.get_pointer(), AX_blas.get_pointer(), s.nks * s.nocc * s.nvirt);
@@ -142,7 +142,7 @@ TEST_F(AXTest, DoubleParallel)
142142
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() }));
143143
Parallel_2D pc;
144144
LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt);
145-
psi::Psi<double, base_device::DEVICE_CPU> c(s.nks, pc.get_col_size(), pc.get_row_size());
145+
psi::Psi<double, base_device::DEVICE_CPU> c(s.nks, pc.get_col_size(), pc.get_row_size(), {}, true);
146146
Parallel_2D px;
147147
LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt);
148148

@@ -152,13 +152,13 @@ TEST_F(AXTest, DoubleParallel)
152152
EXPECT_GE(s.nocc, px.dim1);
153153
EXPECT_GE(s.naos, pc.dim0);
154154

155-
psi::Psi<double, base_device::DEVICE_CPU> AX_pblas_loc(s.nks, nstate, px.get_local_size());
156-
psi::Psi<double, base_device::DEVICE_CPU> AX_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
155+
psi::Psi<double, base_device::DEVICE_CPU> AX_pblas_loc(s.nks, nstate, px.get_local_size(), {}, false);
156+
psi::Psi<double, base_device::DEVICE_CPU> AX_gather(s.nks, nstate, s.nocc * s.nvirt, {}, false);
157157

158158
//set X and X_full
159-
psi::Psi<double, base_device::DEVICE_CPU> X(s.nks, nstate, px.get_local_size(), nullptr, false);
159+
psi::Psi<double, base_device::DEVICE_CPU> X(s.nks, nstate, px.get_local_size(), {}, false);
160160
set_rand(X.get_pointer(), nstate * s.nks * px.get_local_size());
161-
psi::Psi<double, base_device::DEVICE_CPU> X_full(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full
161+
psi::Psi<double, base_device::DEVICE_CPU> X_full(s.nks, nstate, s.nocc * s.nvirt, {}, false); // allocate X_full
162162
for (int istate = 0;istate < nstate;++istate)
163163
{
164164
X.fix_b(istate);
@@ -183,7 +183,7 @@ TEST_F(AXTest, DoubleParallel)
183183
X_full.fix_b(istate);
184184
AX_pblas_loc.fix_b(istate);
185185
AX_gather.fix_b(istate);
186-
CVCX_occ_pblas(V, pV, c, pc, X, px, s.naos, s.nocc, s.nvirt, AX_pblas_loc, false);
186+
LR::CVCX_occ_pblas(V, pV, c, pc, X.get_pointer(), px, s.naos, s.nocc, s.nvirt, AX_pblas_loc.get_pointer(), false);
187187
// gather AX and output
188188
for (int isk = 0;isk < s.nks;++isk)
189189
{
@@ -193,7 +193,7 @@ TEST_F(AXTest, DoubleParallel)
193193
}
194194
// compare to global AX
195195
std::vector<container::Tensor> V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
196-
psi::Psi<double, base_device::DEVICE_CPU> c_full(s.nks, s.nocc + s.nvirt, s.naos);
196+
psi::Psi<double, base_device::DEVICE_CPU> c_full(s.nks, s.nocc + s.nvirt, s.naos, {}, true);
197197
for (int isk = 0;isk < s.nks;++isk)
198198
{
199199
LR_Util::gather_2d_to_full(pV, V.at(isk).data<double>(), V_full.at(isk).data<double>());
@@ -203,15 +203,19 @@ TEST_F(AXTest, DoubleParallel)
203203
}
204204
if (my_rank == 0)
205205
{
206-
psi::Psi<double, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false);
207-
CVCX_occ_blas(V_full, c_full, X_full, s.naos, s.nocc, s.nvirt, AX_full_istate, false);
206+
psi::Psi<double, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, {}, true);
207+
LR::CVCX_occ_blas(V_full, c_full, X_full.get_pointer(), s.naos, s.nocc, s.nvirt, AX_full_istate.get_pointer(), false);
208208
AX_full_istate.fix_b(0);
209209
AX_gather.fix_b(istate);
210210
check_eq(AX_full_istate.get_pointer(), AX_gather.get_pointer(), s.nks * s.nocc * s.nvirt);
211211
}
212212

213213
// //============ the same for virtual ==========
214-
CVCX_virt_pblas(V, pV, c, pc, X, px, s.naos, s.nocc, s.nvirt, AX_pblas_loc, false);
214+
X.fix_b(istate);
215+
X_full.fix_b(istate);
216+
AX_pblas_loc.fix_b(istate);
217+
AX_gather.fix_b(istate);
218+
LR::CVCX_virt_pblas(V, pV, c, pc, X.get_pointer(), px, s.naos, s.nocc, s.nvirt, AX_pblas_loc.get_pointer(), false);
215219
for (int isk = 0;isk < s.nks;++isk)
216220
{
217221
AX_pblas_loc.fix_k(isk);
@@ -220,8 +224,8 @@ TEST_F(AXTest, DoubleParallel)
220224
}
221225
if (my_rank == 0)
222226
{
223-
psi::Psi<double, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false);
224-
CVCX_virt_blas(V_full, c_full, X_full, s.naos, s.nocc, s.nvirt, AX_full_istate, false);
227+
psi::Psi<double, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, {}, true);
228+
LR::CVCX_virt_blas(V_full, c_full, X_full.get_pointer(), s.naos, s.nocc, s.nvirt, AX_full_istate.get_pointer(), false);
225229
AX_full_istate.fix_b(0);
226230
AX_gather.fix_b(istate);
227231
check_eq(AX_full_istate.get_pointer(), AX_gather.get_pointer(), s.nks * s.nocc * s.nvirt);
@@ -240,17 +244,17 @@ TEST_F(AXTest, ComplexParallel)
240244
std::vector<container::Tensor> V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() }));
241245
Parallel_2D pc;
242246
LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt);
243-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> c(s.nks, pc.get_col_size(), pc.get_row_size());
247+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> c(s.nks, pc.get_col_size(), pc.get_row_size(), {}, true);
244248
Parallel_2D px;
245249
LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt);
246250

247-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_pblas_loc(s.nks, nstate, px.get_local_size());
248-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false);
251+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_pblas_loc(s.nks, nstate, px.get_local_size(), {}, false);
252+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_gather(s.nks, nstate, s.nocc * s.nvirt, {}, false);
249253

250254
//set X and X_full
251-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> X(s.nks, nstate, px.get_local_size(), nullptr, false);
255+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> X(s.nks, nstate, px.get_local_size(), {}, false);
252256
set_rand(X.get_pointer(), nstate * s.nks * px.get_local_size());
253-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> X_full(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full
257+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> X_full(s.nks, nstate, s.nocc * s.nvirt, {}, false); // allocate X_full
254258
for (int istate = 0;istate < nstate;++istate)
255259
{
256260
X.fix_b(istate);
@@ -275,7 +279,7 @@ TEST_F(AXTest, ComplexParallel)
275279
X_full.fix_b(istate);
276280
AX_pblas_loc.fix_b(istate);
277281
AX_gather.fix_b(istate);
278-
CVCX_occ_pblas(V, pV, c, pc, X, px, s.naos, s.nocc, s.nvirt, AX_pblas_loc, false);
282+
LR::CVCX_occ_pblas(V, pV, c, pc, X.get_pointer(), px, s.naos, s.nocc, s.nvirt, AX_pblas_loc.get_pointer(), false);
279283

280284
// gather AX and output
281285
for (int isk = 0;isk < s.nks;++isk)
@@ -286,7 +290,7 @@ TEST_F(AXTest, ComplexParallel)
286290
}
287291
// compare to global AX
288292
std::vector<container::Tensor> V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos }));
289-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> c_full(s.nks, s.nocc + s.nvirt, s.naos);
293+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> c_full(s.nks, s.nocc + s.nvirt, s.naos, {}, true);
290294
for (int isk = 0;isk < s.nks;++isk)
291295
{
292296
LR_Util::gather_2d_to_full(pV, V.at(isk).data<std::complex<double>>(), V_full.at(isk).data<std::complex<double>>());
@@ -296,14 +300,18 @@ TEST_F(AXTest, ComplexParallel)
296300
}
297301
if (my_rank == 0)
298302
{
299-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false);
300-
CVCX_occ_blas(V_full, c_full, X_full, s.naos, s.nocc, s.nvirt, AX_full_istate, false);
303+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, {}, false);
304+
LR::CVCX_occ_blas(V_full, c_full, X_full.get_pointer(), s.naos, s.nocc, s.nvirt, AX_full_istate.get_pointer(), false);
301305
AX_full_istate.fix_b(0);
302306
AX_gather.fix_b(istate);
303307
check_eq(AX_full_istate.get_pointer(), AX_gather.get_pointer(), s.nks * s.nocc * s.nvirt);
304308
}
305309
// //============ the same for virtual ==========
306-
CVCX_virt_pblas(V, pV, c, pc, X, px, s.naos, s.nocc, s.nvirt, AX_pblas_loc, false);
310+
X.fix_b(istate);
311+
X_full.fix_b(istate);
312+
AX_pblas_loc.fix_b(istate);
313+
AX_gather.fix_b(istate);
314+
LR::CVCX_virt_pblas(V, pV, c, pc, X.get_pointer(), px, s.naos, s.nocc, s.nvirt, AX_pblas_loc.get_pointer(), false);
307315
for (int isk = 0;isk < s.nks;++isk)
308316
{
309317
AX_pblas_loc.fix_k(isk);
@@ -312,8 +320,8 @@ TEST_F(AXTest, ComplexParallel)
312320
}
313321
if (my_rank == 0)
314322
{
315-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false);
316-
CVCX_virt_blas(V_full, c_full, X_full, s.naos, s.nocc, s.nvirt, AX_full_istate, false);
323+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> AX_full_istate(s.nks, 1, s.nocc * s.nvirt, {}, false);
324+
LR::CVCX_virt_blas(V_full, c_full, X_full.get_pointer(), s.naos, s.nocc, s.nvirt, AX_full_istate.get_pointer(), false);
317325
AX_full_istate.fix_b(0);
318326
AX_gather.fix_b(istate);
319327
check_eq(AX_full_istate.get_pointer(), AX_gather.get_pointer(), s.nks * s.nocc * s.nvirt);
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
remove_definitions(-DUSE_LIBXC)
22
AddTest(
33
TARGET dm_diff_test
4-
LIBS psi base ${math_libs} device container
4+
LIBS psi base parameter ${math_libs} device container
55
SOURCES dm_diff_test.cpp ../../../utils/lr_util.cpp
66
)

0 commit comments

Comments
 (0)