Skip to content

Commit 912a412

Browse files
committed
update the tests for read_wfc_nao and write_wfc_nao
1 parent 616130f commit 912a412

File tree

10 files changed

+104
-47
lines changed

10 files changed

+104
-47
lines changed

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,14 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
166166

167167
// 5) read psi from file
168168
if (PARAM.inp.init_wfc == "file")
169-
{
170-
if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, this->pv, *(this->psi), this->pelec))
169+
{
170+
if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir,
171+
this->pv,
172+
*(this->psi),
173+
this->pelec,
174+
this->pelec->klist->ik2iktot,
175+
this->pelec->klist->get_nkstot(),
176+
PARAM.inp.nspin))
171177
{
172178
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "read electronic wave functions failed");
173179
}

source/module_esolver/lcao_others.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,13 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
174174
// init wfc from file
175175
if (istep == 0 && PARAM.inp.init_wfc == "file")
176176
{
177-
if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, this->pv, *(this->psi), this->pelec))
177+
if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir,
178+
this->pv,
179+
*(this->psi),
180+
this->pelec,
181+
this->pelec->klist->ik2iktot,
182+
this->pelec->klist->get_nkstot(),
183+
PARAM.inp.nspin))
178184
{
179185
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::others", "read wfc nao failed");
180186
}

source/module_io/read_wfc_nao.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, double& data)
1414

1515
void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, std::complex<double>& data)
1616
{
17-
double a = 0.0, b = 0.0;
17+
double a = 0.0;
18+
double b = 0.0;
1819
ifs >> a >> b;
1920
data = std::complex<double>(a, b);
2021
}
@@ -24,15 +25,19 @@ bool ModuleIO::read_wfc_nao(
2425
const std::string& global_readin_dir,
2526
const Parallel_Orbitals& ParaV,
2627
psi::Psi<T>& psid,
27-
elecstate::ElecState* const pelec,
28+
elecstate::ElecState* const pelec,
29+
const std::vector<int> &ik2iktot,
30+
const int nkstot,
31+
const int nspin,
2832
const int skip_band)
2933
{
3034
ModuleBase::TITLE("ModuleIO", "read_wfc_nao");
3135
ModuleBase::timer::tick("ModuleIO", "read_wfc_nao");
3236

33-
int nk = pelec->ekb.nr;
34-
bool gamma_only = std::is_same<T, double>::value;
35-
int out_type = 1; // only support text file now
37+
const int nk = pelec->ekb.nr;
38+
39+
const bool gamma_only = std::is_same<T, double>::value;
40+
const int out_type = 1; // only support .txt file now
3641
bool read_success = true;
3742
int myrank = 0;
3843
int nbands = ParaV.get_wfc_global_nbands(); // the global number of bands
@@ -129,20 +134,31 @@ bool ModuleIO::read_wfc_nao(
129134
}
130135
ifs.close();
131136
return true;
132-
};
137+
}; // end read one file
133138

134139

135140
std::string errors;
136-
std::vector<T> ctot((myrank==0)?nbands * nlocal:0);
141+
142+
std::vector<T> ctot;
143+
if (myrank == 0)
144+
{
145+
ctot.resize(nbands * nlocal);
146+
}
147+
else
148+
{
149+
ctot.resize(0);
150+
}
137151

138152
for(int ik=0;ik<nk;ik++)
139153
{
140154
if (myrank == 0)
141155
{
156+
const bool out_app_flag = false;
157+
const int nstep = -1;
142158
std::stringstream error_message;
143159
std::string ss = global_readin_dir + ModuleIO::wfc_nao_gen_fname(
144-
out_type, gamma_only, false, ik,
145-
pelec->klist->ik2iktot, pelec->klist->get_nkstot(), PARAM.inp.nspin);
160+
out_type, gamma_only, out_app_flag, ik,
161+
ik2iktot, nkstot, nspin, nstep);
146162

147163
read_success = read_one_file(ss, error_message, ik, ctot);
148164
errors = error_message.str();
@@ -186,11 +202,17 @@ bool ModuleIO::read_wfc_nao(
186202
template bool ModuleIO::read_wfc_nao<double>(const std::string& global_readin_dir,
187203
const Parallel_Orbitals& ParaV,
188204
psi::Psi<double>& psid,
189-
elecstate::ElecState* const pelec,
205+
elecstate::ElecState* const pelec,
206+
const std::vector<int> &ik2iktot,
207+
const int nkstot,
208+
const int nspin,
190209
const int skip_band);
191210

192211
template bool ModuleIO::read_wfc_nao<std::complex<double>>(const std::string& global_readin_dir,
193212
const Parallel_Orbitals& ParaV,
194-
psi::Psi<std::complex<double>>& psid,
195-
elecstate::ElecState* const pelec,
196-
const int skip_band);
213+
psi::Psi<std::complex<double>>& psid,
214+
elecstate::ElecState* const pelec,
215+
const std::vector<int> &ik2iktot,
216+
const int nkstot,
217+
const int nspin,
218+
const int skip_band);

source/module_io/read_wfc_nao.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ bool read_wfc_nao(
4040
const std::string& global_readin_dir,
4141
const Parallel_Orbitals& ParaV,
4242
psi::Psi<T>& psid,
43-
elecstate::ElecState* const pelec,
43+
elecstate::ElecState* const pelec,
44+
const std::vector<int> &ik2iktot,
45+
const int nkstot,
46+
const int nspin,
4447
const int skip_band = 0);
4548

4649
} // namespace ModuleIO

source/module_io/test/read_wfc_nao_test.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ namespace elecstate
1616

1717
// mock wfc_lcao_gen_fname
1818
std::string ModuleIO::wfc_nao_gen_fname(const int out_type,
19-
const bool gamma_only,
20-
const bool out_app_flag,
21-
const int ik,
22-
const int istep)
19+
const bool gamma_only,
20+
const bool out_app_flag,
21+
const int ik,
22+
const std::vector<int> &ik2iktot,
23+
const int nkstot,
24+
const int nspin,
25+
const int istep)
2326
{
24-
return "WFC_NAO_GAMMA2.txt";
27+
return "wfs1_nao.txt";
2528
}
2629

2730
/************************************************
@@ -45,10 +48,11 @@ class ReadWfcNaoTest : public ::testing::Test
4548
TEST_F(ReadWfcNaoTest,ReadWfcNao)
4649
{
4750
//Global variables
48-
int nbands = 3;
49-
int nlocal = 3;
51+
const int nbands = 3;
52+
const int nlocal = 3;
5053
PARAM.sys.global_readin_dir = "./support/";
51-
int nks = 1;
54+
const int nks = 1;
55+
const int nspin = 1;
5256
int my_rank = 0;
5357

5458
Parallel_Orbitals ParaV;
@@ -68,8 +72,13 @@ TEST_F(ReadWfcNaoTest,ReadWfcNao)
6872
elecstate::ElecState pelec;
6973
pelec.ekb.create(nks,nbands);
7074
pelec.wg.create(nks,nbands);
75+
76+
std::vector<int> ik2iktot = {0};
77+
const int nkstot = 1;
78+
7179
// Act
72-
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid, &(pelec));
80+
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid,
81+
&(pelec), ik2iktot, nkstot, nspin);
7382
// Assert
7483
EXPECT_NEAR(pelec.ekb(0,1),0.31482195194888534794941393,1e-5);
7584
EXPECT_NEAR(pelec.wg(0,1),0.0,1e-5);
@@ -78,6 +87,7 @@ TEST_F(ReadWfcNaoTest,ReadWfcNao)
7887
EXPECT_NEAR(psid(0,0,0),5.3759239842e-01,1e-5);
7988
}
8089
}
90+
8191
TEST_F(ReadWfcNaoTest, ReadWfcNaoPart)
8292
{
8393
//Global variables
@@ -86,6 +96,7 @@ TEST_F(ReadWfcNaoTest, ReadWfcNaoPart)
8696
const int nlocal = 3;
8797
PARAM.sys.global_readin_dir = "./support/";
8898
const int nks = 1;
99+
const int nspin = 1;
89100
int my_rank = 0;
90101

91102
Parallel_Orbitals ParaV;
@@ -105,8 +116,14 @@ TEST_F(ReadWfcNaoTest, ReadWfcNaoPart)
105116
elecstate::ElecState pelec;
106117
pelec.ekb.create(nks, nbands);
107118
pelec.wg.create(nks, nbands);
108-
// Act
109-
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid, &(pelec), /*skip_band=*/1);
119+
120+
std::vector<int> ik2iktot = {0};
121+
const int nkstot = 1;
122+
123+
// Act
124+
ModuleIO::read_wfc_nao(PARAM.sys.global_readin_dir, ParaV, psid,
125+
&(pelec), ik2iktot, nkstot, nspin, skip_band);
126+
110127
// Assert
111128
EXPECT_NEAR(pelec.ekb(0, 1), 7.4141254894954844445464914e-01, 1e-5);
112129
if (my_rank == 0)

source/module_io/test/write_wfc_nao_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ TEST(GenWfcLcaoFnameTest, OutType1GammaOnlyOutAppFlagTrue)
2727
// if out_app_flag = true, then the 'g' label will not show up
2828
const int istep = 0;
2929

30-
std::string expected_output = "wfs1k1_nao.txt";
30+
std::string expected_output = "wfs1_nao.txt";
3131
std::string result = ModuleIO::wfc_nao_gen_fname(out_type, gamma_only, out_app_flag, ik,
3232
ik2iktot, nkstot, nspin, istep);
3333

source/module_io/write_wfc_nao.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ namespace ModuleIO
1414
{
1515

1616
std::string wfc_nao_gen_fname(const int out_type,
17-
const bool gamma_only,
18-
const bool out_app_flag,
19-
const int ik,
20-
const std::vector<int> &ik2iktot,
21-
const int nkstot,
22-
const int nspin,
23-
const int istep)
17+
const bool gamma_only,
18+
const bool out_app_flag,
19+
const int ik,
20+
const std::vector<int> &ik2iktot,
21+
const int nkstot,
22+
const int nspin,
23+
const int istep)
2424
{
2525
// fn_out = "{PARAM.globalv.global_out_dir}/wf{s}{spin index}{k(optinal)}{k-point index}
2626
// {g(optional)}{geometry index1}{_nao} + {".txt"/".dat"}""
@@ -88,7 +88,7 @@ std::string wfc_nao_gen_fname(const int out_type,
8888
else
8989
{
9090
std::cout << "WARNING: the type of output wave function is not 1 or 2, so 1 is chosen." << std::endl;
91-
suffix_block = ".txt";
91+
suffix_block = "_nao.txt";
9292
}
9393

9494
std::string fn_out

source/module_io/write_wfc_nao.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ namespace ModuleIO
2020
* @return The generated filename.
2121
*/
2222
std::string wfc_nao_gen_fname(const int out_type,
23-
const bool gamma_only,
24-
const bool out_app_flag,
25-
const int ik,
26-
const std::vector<int> &ik2iktot,
27-
const int nkstot,
28-
const int nspin,
29-
const int istep=-1);
23+
const bool gamma_only,
24+
const bool out_app_flag,
25+
const int ik,
26+
const std::vector<int> &ik2iktot,
27+
const int nkstot,
28+
const int nspin,
29+
const int istep=-1);
3030

31-
/**
31+
/**
3232
* Writes the wavefunction coefficients for the LCAO method to a file.
3333
* Will loop all k-points by psi.get_nk().
3434
* The nbands are determined by ekb.nc.

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,11 @@ void LR::ESolver_LR<T, TR>::read_ks_wfc()
725725
ModuleBase::WARNING_QUIT("ESolver_LR", "RI benchmark is only supported when compile with LibRI.");
726726
#endif
727727
}
728-
else if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, this->paraMat_, *this->psi_ks, this->pelec,
729-
/*skip_bands=*/this->nocc_max - this->nocc_in)) {
728+
else if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, this->paraMat_, *this->psi_ks,
729+
this->pelec,
730+
this->pelec->klist->ik2iktot,
731+
this->pelec->klist->get_nkstot(),
732+
/*skip_bands=*/this->nocc_max - this->nocc_in)) {
730733
ModuleBase::WARNING_QUIT("ESolver_LR", "read ground-state wavefunction failed.");
731734
}
732735
this->eig_ks = std::move(this->pelec->ekb);

0 commit comments

Comments
 (0)