Skip to content

Commit 032dc13

Browse files
committed
update pw wf codes
1 parent 5cdafa7 commit 032dc13

17 files changed

+133
-326
lines changed

source/module_cell/unitcell.cpp

100755100644
File mode changed.

source/module_elecstate/module_charge/charge_init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout,
247247
const K_Vectors* kv = reinterpret_cast<const K_Vectors*>(klist);
248248
const int nkstot = kv->get_nkstot();
249249
const std::vector<int>& isk = kv->isk;
250-
ModuleIO::read_wf2rho_pw(pw_wfc, symm, kv->ik2iktot.data(), nkstot, isk, *this);
250+
ModuleIO::read_wf2rho_pw(pw_wfc, symm, kv->ik2iktot, nkstot, isk, *this);
251251
}
252252
}
253253

source/module_io/read_wf2rho_pw.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
#include "module_elecstate/module_charge/symmetry_rho.h"
77
#include "module_parameter/parameter.h"
88
#include "module_elecstate/kernels/elecstate_op.h"
9+
#include "module_io/filename.h"
910

1011
void ModuleIO::read_wf2rho_pw(const ModulePW::PW_Basis_K* pw_wfc,
1112
ModuleSymmetry::Symmetry& symm,
12-
const int* ik2iktot,
13+
const std::vector<int> &ik2iktot,
1314
const int nkstot,
14-
const std::vector<int>& isk,
15+
const std::vector<int> &isk,
1516
Charge& chg)
1617
{
1718
ModuleBase::TITLE("ModuleIO", "read_wf2rho_pw");
@@ -78,11 +79,20 @@ void ModuleIO::read_wf2rho_pw(const ModulePW::PW_Basis_K* pw_wfc,
7879
{
7980
is = isk[ik];
8081
}
81-
std::stringstream filename;
8282
const int ikstot = ik2iktot[ik];
83-
filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ikstot + 1 << ".dat";
84-
ModuleIO::read_wfc_pw(filename.str(), pw_wfc, ik, ikstot, nkstot, wfc_tmp);
85-
if (PARAM.inp.nspin == 4)
83+
84+
// mohan add 2025-05-17
85+
// .dat file
86+
const int out_type = 2;
87+
const bool out_app_flag = false;
88+
const bool gamma_only = false;
89+
const int istep = -1;
90+
91+
std::string fn = filename_output(PARAM.globalv.global_readin_dir,"wf","pw",ik,ik2iktot,nspin,nkstot,
92+
out_type,out_app_flag,gamma_only,istep);
93+
94+
ModuleIO::read_wfc_pw(fn, pw_wfc, ik, ikstot, nkstot, wfc_tmp);
95+
if (nspin == 4)
8696
{
8797
std::vector<std::complex<double>> rho_tmp2(nrxx);
8898
for (int ib = 0; ib < nbands; ++ib)

source/module_io/read_wf2rho_pw.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "module_basis/module_pw/pw_basis_k.h"
55
#include "module_elecstate/module_charge/charge.h"
66

7+
#include <vector>
78
#include <string>
89

910
namespace ModuleIO
@@ -19,9 +20,9 @@ namespace ModuleIO
1920
*/
2021
void read_wf2rho_pw(const ModulePW::PW_Basis_K* pw_wfc,
2122
ModuleSymmetry::Symmetry& symm,
22-
const int* ik2iktot,
23+
const std::vector<int> &ik2iktot,
2324
const int nkstot,
24-
const std::vector<int>& isk,
25+
const std::vector<int> &isk,
2526
Charge& chg);
2627

2728
} // namespace ModuleIO

source/module_io/read_wfc_nao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ bool ModuleIO::read_wfc_nao(
6767
ifs.open(ss.c_str());
6868
if (!ifs)
6969
{
70-
error_message << " Can't open file:" << ss << std::endl;
70+
error_message << " Can't open file: " << ss << std::endl;
7171
return false;
7272
}
7373
else

source/module_io/read_wfc_nao.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_IO_READ_WFC_NAO_H
2-
#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_IO_READ_WFC_NAO_H
1+
#ifndef READ_WFC_NAO_H
2+
#define READ_WFC_NAO_H
33

44
#include "module_basis/module_ao/parallel_orbitals.h"
55
#include "module_psi/psi.h"

source/module_io/test/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ AddTest(
8585
AddTest(
8686
TARGET io_write_wfc_nao
8787
LIBS parameter ${math_libs} base psi device
88-
SOURCES write_wfc_nao_test.cpp ../write_wfc_nao.cpp ../../module_basis/module_ao/parallel_orbitals.cpp ../binstream.cpp
88+
SOURCES write_wfc_nao_test.cpp ../filename.cpp ../write_wfc_nao.cpp ../../module_basis/module_ao/parallel_orbitals.cpp ../binstream.cpp
8989
)
9090

9191
install(FILES write_wfc_nao_para.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
@@ -173,8 +173,7 @@ add_test(NAME read_wfc_pw_test_parallel
173173
AddTest(
174174
TARGET read_wf2rho_pw_test
175175
LIBS parameter base ${math_libs} device planewave psi
176-
SOURCES read_wf2rho_pw_test.cpp ../read_wfc_pw.cpp ../read_wf2rho_pw.cpp ../binstream.cpp ../../module_basis/module_pw/test/test_tool.cpp
177-
../../module_elecstate/module_charge/charge_mpi.cpp ../write_wfc_pw.cpp
176+
SOURCES read_wf2rho_pw_test.cpp ../read_wfc_pw.cpp ../read_wf2rho_pw.cpp ../binstream.cpp ../../module_basis/module_pw/test/test_tool.cpp ../../module_elecstate/module_charge/charge_mpi.cpp ../filename.cpp ../write_wfc_pw.cpp
178177
)
179178

180179
add_test(NAME read_wf2rho_pw_parallel

source/module_io/test/read_wf2rho_pw_test.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h"
1313
#include "module_io/read_wf2rho_pw.h"
1414
#include "module_io/write_wfc_pw.h"
15+
#include "module_io/filename.h" // mohan add 2025-05-17
1516
#include "module_parameter/parameter.h"
1617
#include "module_psi/psi.h"
1718

@@ -219,10 +220,11 @@ TEST_F(ReadWfcRhoTest, ReadWfcRho)
219220
#endif
220221

221222
// Write the wave functions to file
222-
ModuleIO::write_wfc_pw("WAVEFUNC", *psi, *kv, wfcpw);
223+
const std::string global_out_dir = "./";
224+
ModuleIO::write_wfc_pw(PARAM.input.out_wfc_pw, global_out_dir, *psi, *kv, wfcpw);
223225

224226
// Read the wave functions to charge density
225-
ModuleIO::read_wf2rho_pw(wfcpw, symm, kv->ik2iktot.data(), nkstot, kv->isk, chg);
227+
ModuleIO::read_wf2rho_pw(wfcpw, symm, kv->ik2iktot, nkstot, kv->isk, chg);
226228

227229
// compare the charge density
228230
for (int ir = 0; ir < rhopw->nrxx; ++ir)
@@ -231,22 +233,24 @@ TEST_F(ReadWfcRhoTest, ReadWfcRho)
231233
}
232234
// std::cout.precision(16);
233235
// std::cout<<chg.rho[0][0]<<std::endl;
234-
if (GlobalV::NPROC == 1) {
235-
EXPECT_NEAR(chg.rho[0][0], 8617.076357957576, 1e-8);
236-
} else if (GlobalV::NPROC == 4)
237-
{
238-
const std::vector<double> ref = {8207.849135313403, 35.34776105132742, 8207.849135313403, 35.34776105132742};
239-
EXPECT_NEAR(chg.rho[0][0], ref[GlobalV::MY_RANK], 1e-8);
240-
// for (int ip = 0; ip < GlobalV::NPROC; ++ip)
241-
// {
242-
// if (GlobalV::MY_RANK == ip)
243-
// {
244-
// std::cout.precision(16);
245-
// std::cout << GlobalV::MY_RANK << " " << chg.rho[0][0] << std::endl;
246-
// }
247-
// MPI_Barrier(MPI_COMM_WORLD);
248-
// }
249-
}
236+
if (GlobalV::NPROC == 1)
237+
{
238+
EXPECT_NEAR(chg.rho[0][0], 8617.076357957576, 1e-8);
239+
}
240+
else if (GlobalV::NPROC == 4)
241+
{
242+
const std::vector<double> ref = {8207.849135313403, 35.34776105132742, 8207.849135313403, 35.34776105132742};
243+
EXPECT_NEAR(chg.rho[0][0], ref[GlobalV::MY_RANK], 1e-8);
244+
// for (int ip = 0; ip < GlobalV::NPROC; ++ip)
245+
// {
246+
// if (GlobalV::MY_RANK == ip)
247+
// {
248+
// std::cout.precision(16);
249+
// std::cout << GlobalV::MY_RANK << " " << chg.rho[0][0] << std::endl;
250+
// }
251+
// MPI_Barrier(MPI_COMM_WORLD);
252+
// }
253+
}
250254

251255
delete[] chg.rho;
252256
delete[] chg._space_rho;

source/module_io/test/read_wfc_nao_test.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,25 @@ namespace elecstate
1414
const double* ElecState::getRho(int spin) const{return &(this->eferm.ef);}//just for mock
1515
}
1616

17-
// mock wfc_lcao_gen_fname
18-
std::string ModuleIO::wfc_nao_gen_fname(const int out_type,
19-
const bool gamma_only,
20-
const bool out_app_flag,
21-
const int ik,
22-
const std::vector<int> &ik2iktot,
23-
const int nkstot,
24-
const int nspin,
25-
const int istep)
17+
18+
namespace ModuleIO
19+
{
20+
// mock filename_output
21+
std::string filename_output(
22+
const std::string &directory,
23+
const std::string &property,
24+
const std::string &basis,
25+
const int ik,
26+
const std::vector<int> &ik2iktot,
27+
const int nspin,
28+
const int nkstot,
29+
const int out_type,
30+
const bool out_app_flag,
31+
const bool gamma_only,
32+
const int istep)
2633
{
27-
return "wfs1_nao.txt";
34+
return "./support/wfs1_nao.txt";
35+
}
2836
}
2937

3038
/************************************************

source/module_io/test/read_wfc_pw_test.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ReadWfcPwTest : public ::testing::Test
4242
// Test the read_wfc_pw function
4343
TEST_F(ReadWfcPwTest, ReadWfcPw)
4444
{
45-
std::string filename = "./support/WAVEFUNC1.dat";
45+
std::string filename = "./support/wfs1k1_pw.dat";
4646

4747
#ifdef __MPI
4848
wfcpw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
@@ -152,7 +152,7 @@ TEST_F(ReadWfcPwTest, InconsistentBands)
152152
{
153153
if (GlobalV::NPROC_IN_POOL == 1)
154154
{
155-
std::string filename = "./support/WAVEFUNC1.dat";
155+
std::string filename = "./support/wfs1k1_pw.dat";
156156

157157
#ifdef __MPI
158158
wfcpw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
@@ -182,7 +182,7 @@ TEST_F(ReadWfcPwTest, InconsistentKvec)
182182
{
183183
if (GlobalV::NPROC_IN_POOL == 1)
184184
{
185-
std::string filename = "./support/WAVEFUNC1.dat";
185+
std::string filename = "./support/wfs1k1_pw.dat";
186186

187187
kvec_d[0] = ModuleBase::Vector3<double>(0.0, 0.0, 1.0);
188188

@@ -211,7 +211,8 @@ TEST_F(ReadWfcPwTest, InconsistentLat0)
211211
{
212212
if (GlobalV::NPROC_IN_POOL == 1)
213213
{
214-
std::string filename = "./support/WAVEFUNC1.dat";
214+
std::string filename = "./support/wfs1k1_pw.dat";
215+
215216
kvec_d[0] = ModuleBase::Vector3<double>(0.0, 0.0, 0.0);
216217

217218
#ifdef __MPI
@@ -239,7 +240,7 @@ TEST_F(ReadWfcPwTest, InconsistentG)
239240
{
240241
if (GlobalV::NPROC_IN_POOL == 1)
241242
{
242-
std::string filename = "./support/WAVEFUNC1.dat";
243+
std::string filename = "./support/wfs1k1_pw.dat";
243244
kvec_d[0] = ModuleBase::Vector3<double>(0.0, 0.0, 0.0);
244245

245246
#ifdef __MPI
@@ -284,4 +285,4 @@ int main(int argc, char** argv)
284285
finishmpi();
285286
#endif
286287
return result;
287-
}
288+
}

0 commit comments

Comments
 (0)