Skip to content

Commit 11e4bb4

Browse files
committed
update the output formats of pw wave functions
1 parent b144abf commit 11e4bb4

14 files changed

+185
-252
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -734,14 +734,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
734734
}
735735

736736

737-
// tmp 2025-05-17, mohan note
738-
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
739-
{
740-
std::stringstream ssw;
741-
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
742-
ModuleIO::write_wfc_pw(ssw.str(), this->psi[0], this->kv, this->pw_wfc);
743-
}
744-
737+
// tmp 2025-05-17, mohan note
738+
ModuleIO::write_wfc_pw(PARAM.inp.out_wfc_pw, PARAM.globalv.global_out_dir,
739+
this->psi[0], this->kv, this->pw_wfc);
745740

746741
//------------------------------------------------------------------
747742
//! 5) calculate Wannier functions in pw basis

source/module_io/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ list(APPEND objects
3333
para_json.cpp
3434
parse_args.cpp
3535
orb_io.cpp
36+
filename.cpp
3637
)
3738

3839
list(APPEND objects_advanced

source/module_io/get_wf_lcao.cpp

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -221,19 +221,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
221221
}
222222
}
223223

224-
if (out_wfc_pw)
225-
{
226-
std::stringstream ssw;
227-
ssw << global_out_dir << "WAVEFUNC";
228-
std::cout << " Write G-space wave functions into \"" << global_out_dir << "/" << ssw.str() << "\" files."
229-
<< std::endl;
230-
ModuleIO::write_wfc_pw(ssw.str(), psi_g, kv, pw_wfc);
231-
}
232-
233-
// if (out_wfc_r)
234-
// {
235-
// ModuleIO::write_psi_r_1(ucell, psi_g, pw_wfc, "wfc_realspace", false, kv);
236-
// }
224+
ModuleIO::write_wfc_pw(out_wfc_pw,global_out_dir,psi_g, kv, pw_wfc);
237225

238226
for (int is = 0; is < nspin; ++is)
239227
{
@@ -381,18 +369,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
381369
}
382370
}
383371

384-
if (out_wf)
385-
{
386-
std::stringstream ssw;
387-
ssw << global_out_dir << "WAVEFUNC";
388-
std::cout << " write G-space wave functions into \"" << global_out_dir << "/" << ssw.str() << "\" files."
389-
<< std::endl;
390-
ModuleIO::write_wfc_pw(ssw.str(), psi_g, kv, pw_wfc);
391-
}
392-
// if (out_wf_r)
393-
// {
394-
// ModuleIO::write_psi_r_1(ucell, psi_g, pw_wfc, "wfc_realspace", false, kv);
395-
// }
372+
ModuleIO::write_wfc_pw(out_wf,global_out_dir,psi_g, kv, pw_wfc);
396373

397374
std::cout << " Outputting real-space wave functions in cube format..." << std::endl;
398375

source/module_io/read_wfc_nao.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "write_wfc_nao.h"
88
#include "module_base/scalapack_connector.h"
9+
#include "module_io/filename.h"
910

1011
void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, double& data)
1112
{
@@ -154,11 +155,11 @@ bool ModuleIO::read_wfc_nao(
154155
if (myrank == 0)
155156
{
156157
const bool out_app_flag = false;
157-
const int nstep = -1;
158+
const int istep = -1;
158159
std::stringstream error_message;
159-
std::string ss = global_readin_dir + ModuleIO::wfc_nao_gen_fname(
160-
out_type, gamma_only, out_app_flag, ik,
161-
ik2iktot, nkstot, nspin, nstep);
160+
161+
std::string ss = ModuleIO::filename_output(global_readin_dir,"wf","nao",
162+
ik,ik2iktot,nspin,nkstot,out_type,out_app_flag,gamma_only,istep);
162163

163164
read_success = read_one_file(ss, error_message, ik, ctot);
164165
errors = error_message.str();

source/module_io/test/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ AddTest(
8585
AddTest(
8686
TARGET MODULE_IO_write_wfc_nao
8787
LIBS parameter ${math_libs} base psi device
88-
SOURCES write_wfc_nao_test.cpp ../write_wfc_nao.cpp ../../module_basis/module_ao/parallel_orbitals.cpp ../binstream.cpp
88+
SOURCES write_wfc_nao_test.cpp ../filename.cpp ../write_wfc_nao.cpp ../../module_basis/module_ao/parallel_orbitals.cpp ../binstream.cpp
8989
)
9090

9191
install(FILES write_wfc_nao_para.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
@@ -173,8 +173,7 @@ add_test(NAME MODULE_IO_read_wfc_pw_test_parallel
173173
AddTest(
174174
TARGET MODULE_IO_read_wf2rho_pw_test
175175
LIBS parameter base ${math_libs} device planewave psi
176-
SOURCES read_wf2rho_pw_test.cpp ../read_wfc_pw.cpp ../read_wf2rho_pw.cpp ../binstream.cpp ../../module_basis/module_pw/test/test_tool.cpp
177-
../../module_elecstate/module_charge/charge_mpi.cpp ../write_wfc_pw.cpp
176+
SOURCES read_wf2rho_pw_test.cpp ../read_wfc_pw.cpp ../read_wf2rho_pw.cpp ../binstream.cpp ../../module_basis/module_pw/test/test_tool.cpp ../../module_elecstate/module_charge/charge_mpi.cpp ../filename.cpp ../write_wfc_pw.cpp
178177
)
179178

180179
add_test(NAME MODULE_IO_read_wf2rho_pw_parallel

source/module_io/test/read_input_ptest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,4 +479,4 @@ int main(int argc, char** argv)
479479
MPI_Finalize();
480480
return result;
481481
}
482-
// #endif
482+
// #endif

source/module_io/test/read_wf2rho_pw_test.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h"
1313
#include "module_io/read_wf2rho_pw.h"
1414
#include "module_io/write_wfc_pw.h"
15+
#include "module_io/filename.h" // mohan add 2025-05-17
1516
#include "module_parameter/parameter.h"
1617
#include "module_psi/psi.h"
1718

@@ -219,10 +220,11 @@ TEST_F(ReadWfcRhoTest, ReadWfcRho)
219220
#endif
220221

221222
// Write the wave functions to file
222-
ModuleIO::write_wfc_pw("WAVEFUNC", *psi, *kv, wfcpw);
223+
const std::string global_out_dir = "./";
224+
ModuleIO::write_wfc_pw(PARAM.input.out_wfc_pw, global_out_dir, *psi, *kv, wfcpw);
223225

224226
// Read the wave functions to charge density
225-
std::ofstream running_log("running_log.txt");
227+
std::ofstream running_log("running_log.txt");
226228
ModuleIO::read_wf2rho_pw(wfcpw, symm, kv->ik2iktot, nkstot, kv->isk, chg, running_log);
227229

228230
// compare the charge density
@@ -232,22 +234,24 @@ TEST_F(ReadWfcRhoTest, ReadWfcRho)
232234
}
233235
// std::cout.precision(16);
234236
// std::cout<<chg.rho[0][0]<<std::endl;
235-
if (GlobalV::NPROC == 1) {
236-
EXPECT_NEAR(chg.rho[0][0], 8617.076357957576, 1e-8);
237-
} else if (GlobalV::NPROC == 4)
238-
{
239-
const std::vector<double> ref = {8207.849135313403, 35.34776105132742, 8207.849135313403, 35.34776105132742};
240-
EXPECT_NEAR(chg.rho[0][0], ref[GlobalV::MY_RANK], 1e-8);
241-
// for (int ip = 0; ip < GlobalV::NPROC; ++ip)
242-
// {
243-
// if (GlobalV::MY_RANK == ip)
244-
// {
245-
// std::cout.precision(16);
246-
// std::cout << GlobalV::MY_RANK << " " << chg.rho[0][0] << std::endl;
247-
// }
248-
// MPI_Barrier(MPI_COMM_WORLD);
249-
// }
250-
}
237+
if (GlobalV::NPROC == 1)
238+
{
239+
EXPECT_NEAR(chg.rho[0][0], 8617.076357957576, 1e-8);
240+
}
241+
else if (GlobalV::NPROC == 4)
242+
{
243+
const std::vector<double> ref = {8207.849135313403, 35.34776105132742, 8207.849135313403, 35.34776105132742};
244+
EXPECT_NEAR(chg.rho[0][0], ref[GlobalV::MY_RANK], 1e-8);
245+
// for (int ip = 0; ip < GlobalV::NPROC; ++ip)
246+
// {
247+
// if (GlobalV::MY_RANK == ip)
248+
// {
249+
// std::cout.precision(16);
250+
// std::cout << GlobalV::MY_RANK << " " << chg.rho[0][0] << std::endl;
251+
// }
252+
// MPI_Barrier(MPI_COMM_WORLD);
253+
// }
254+
}
251255

252256
delete[] chg.rho;
253257
delete[] chg._space_rho;

source/module_io/test/read_wfc_nao_test.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,25 @@ namespace elecstate
1414
const double* ElecState::getRho(int spin) const{return &(this->eferm.ef);}//just for mock
1515
}
1616

17-
// mock wfc_lcao_gen_fname
18-
std::string ModuleIO::wfc_nao_gen_fname(const int out_type,
19-
const bool gamma_only,
20-
const bool out_app_flag,
21-
const int ik,
22-
const std::vector<int> &ik2iktot,
23-
const int nkstot,
24-
const int nspin,
25-
const int istep)
17+
18+
namespace ModuleIO
19+
{
20+
// mock filename_output
21+
std::string filename_output(
22+
const std::string &directory,
23+
const std::string &property,
24+
const std::string &basis,
25+
const int ik,
26+
const std::vector<int> &ik2iktot,
27+
const int nspin,
28+
const int nkstot,
29+
const int out_type,
30+
const bool out_app_flag,
31+
const bool gamma_only,
32+
const int istep)
2633
{
27-
return "wfs1_nao.txt";
34+
return "./support/wfs1_nao.txt";
35+
}
2836
}
2937

3038
/************************************************

source/module_io/test/read_wfc_pw_test.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ReadWfcPwTest : public ::testing::Test
4242
// Test the read_wfc_pw function
4343
TEST_F(ReadWfcPwTest, ReadWfcPw)
4444
{
45-
std::string filename = "./support/WAVEFUNC1.dat";
45+
std::string filename = "./support/wfs1k1_pw.dat";
4646

4747
#ifdef __MPI
4848
wfcpw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
@@ -152,7 +152,7 @@ TEST_F(ReadWfcPwTest, InconsistentBands)
152152
{
153153
if (GlobalV::NPROC_IN_POOL == 1)
154154
{
155-
std::string filename = "./support/WAVEFUNC1.dat";
155+
std::string filename = "./support/wfs1k1_pw.dat";
156156

157157
#ifdef __MPI
158158
wfcpw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
@@ -182,7 +182,7 @@ TEST_F(ReadWfcPwTest, InconsistentKvec)
182182
{
183183
if (GlobalV::NPROC_IN_POOL == 1)
184184
{
185-
std::string filename = "./support/WAVEFUNC1.dat";
185+
std::string filename = "./support/wfs1k1_pw.dat";
186186

187187
kvec_d[0] = ModuleBase::Vector3<double>(0.0, 0.0, 1.0);
188188

@@ -211,7 +211,8 @@ TEST_F(ReadWfcPwTest, InconsistentLat0)
211211
{
212212
if (GlobalV::NPROC_IN_POOL == 1)
213213
{
214-
std::string filename = "./support/WAVEFUNC1.dat";
214+
std::string filename = "./support/wfs1k1_pw.dat";
215+
215216
kvec_d[0] = ModuleBase::Vector3<double>(0.0, 0.0, 0.0);
216217

217218
#ifdef __MPI
@@ -239,7 +240,7 @@ TEST_F(ReadWfcPwTest, InconsistentG)
239240
{
240241
if (GlobalV::NPROC_IN_POOL == 1)
241242
{
242-
std::string filename = "./support/WAVEFUNC1.dat";
243+
std::string filename = "./support/wfs1k1_pw.dat";
243244
kvec_d[0] = ModuleBase::Vector3<double>(0.0, 0.0, 0.0);
244245

245246
#ifdef __MPI
@@ -284,4 +285,4 @@ int main(int argc, char** argv)
284285
finishmpi();
285286
#endif
286287
return result;
287-
}
288+
}

0 commit comments

Comments
 (0)