Skip to content

Commit a0da04e

Browse files
committed
fix type conversion and file open check
1 parent d1b4fc2 commit a0da04e

2 files changed

Lines changed: 85 additions & 39 deletions

File tree

source/source_cell/read_pp_upf201.cpp

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
#include "read_pp.h"
22

3+
// convert helper function
4+
template<typename T>
5+
T safe_convert(const std::string& str, T default_value = T{}) {
6+
try {
7+
if constexpr (std::is_same_v<T, int>) {
8+
return std::stoi(str);
9+
} else if constexpr (std::is_same_v<T, double>) {
10+
return std::stod(str);
11+
} else if constexpr (std::is_same_v<T, float>) {
12+
return std::stof(str);
13+
}
14+
} catch (const std::invalid_argument& e) {
15+
std::cerr << "Warning: Invalid number format '" << str << "', using default" << std::endl;
16+
return default_value;
17+
} catch (const std::out_of_range& e) {
18+
std::cerr << "Warning: Number out of range '" << str << "', using default" << std::endl;
19+
return default_value;
20+
}
21+
return default_value;
22+
}
23+
324
// qianrui rewrite it 2021-5-10
425
// liuyu update 2023-09-17 add uspp support
526
int Pseudopot_upf::read_pseudo_upf201(std::ifstream &ifs, Atom_pseudo& pp)
@@ -303,35 +324,35 @@ void Pseudopot_upf::read_pseudo_upf201_header(std::ifstream& ifs, Atom_pseudo& p
303324
}
304325
else if (name[ip] == "z_valence")
305326
{
306-
pp.zv = std::stod(val[ip]);
327+
pp.zv = safe_convert<double>(val[ip], 0.0);
307328
}
308329
else if (name[ip] == "total_psenergy")
309330
{
310-
pp.etotps = atof(val[ip].c_str());
331+
pp.etotps = safe_convert<float>(val[ip], 0.0);
311332
}
312333
else if (name[ip] == "wfc_cutoff")
313334
{
314-
pp.ecutwfc = atof(val[ip].c_str());
335+
pp.ecutwfc = safe_convert<float>(val[ip], 0.0);
315336
}
316337
else if (name[ip] == "rho_cutoff")
317338
{
318-
pp.ecutrho = atof(val[ip].c_str());
339+
pp.ecutrho = safe_convert<float>(val[ip], 0.0);
319340
}
320341
else if (name[ip] == "l_max")
321342
{
322-
pp.lmax = atoi(val[ip].c_str());
343+
pp.lmax = safe_convert<int>(val[ip], 0);
323344
}
324345
else if (name[ip] == "l_max_rho")
325346
{
326-
this->lmax_rho = atoi(val[ip].c_str());
347+
this->lmax_rho = safe_convert<int>(val[ip], 0);
327348
}
328349
else if (name[ip] == "l_local")
329350
{
330-
this->lloc = atoi(val[ip].c_str());
351+
this->lloc = safe_convert<int>(val[ip], 0);
331352
}
332353
else if (name[ip] == "mesh_size")
333354
{
334-
pp.mesh = atoi(val[ip].c_str());
355+
pp.mesh = safe_convert<int>(val[ip], 0);
335356
this->mesh_changed = false;
336357
if (pp.mesh % 2 == 0)
337358
{
@@ -341,11 +362,11 @@ void Pseudopot_upf::read_pseudo_upf201_header(std::ifstream& ifs, Atom_pseudo& p
341362
}
342363
else if (name[ip] == "number_of_wfc")
343364
{
344-
pp.nchi = atoi(val[ip].c_str());
365+
pp.nchi = safe_convert<int>(val[ip], 0);
345366
}
346367
else if (name[ip] == "number_of_proj")
347368
{
348-
pp.nbeta = atoi(val[ip].c_str());
369+
pp.nbeta = safe_convert<int>(val[ip], 0);
349370
}
350371
else
351372
{
@@ -377,11 +398,11 @@ void Pseudopot_upf::read_pseudo_upf201_mesh(std::ifstream& ifs, Atom_pseudo& pp)
377398
{
378399
if (name[ip] == "dx")
379400
{
380-
dx = atof(val[ip].c_str());
401+
dx = safe_convert<float>(val[ip], 0.0);
381402
}
382403
else if (name[ip] == "mesh")
383404
{
384-
pp.mesh = atoi(val[ip].c_str());
405+
pp.mesh = safe_convert<int>(val[ip], 0);
385406

386407
this->mesh_changed = false;
387408
if (pp.mesh % 2 == 0)
@@ -392,15 +413,15 @@ void Pseudopot_upf::read_pseudo_upf201_mesh(std::ifstream& ifs, Atom_pseudo& pp)
392413
}
393414
else if (name[ip] == "xmin")
394415
{
395-
xmin = atof(val[ip].c_str());
416+
xmin = safe_convert<float>(val[ip], 0.0);
396417
}
397418
else if (name[ip] == "rmax")
398419
{
399-
rmax = atof(val[ip].c_str());
420+
rmax = safe_convert<float>(val[ip], 0.0);
400421
}
401422
else if (name[ip] == "zmesh")
402423
{
403-
zmesh = atof(val[ip].c_str());
424+
zmesh = safe_convert<float>(val[ip], 0.0);
404425
}
405426
else
406427
{
@@ -501,19 +522,19 @@ void Pseudopot_upf::read_pseudo_upf201_nonlocal(std::ifstream& ifs, Atom_pseudo&
501522
}
502523
else if (name[ip] == "angular_momentum")
503524
{
504-
pp.lll[ib] = atoi(val[ip].c_str());
525+
pp.lll[ib] = safe_convert<int>(val[ip], 0);
505526
}
506527
else if (name[ip] == "cutoff_radius_index")
507528
{
508-
this->kbeta[ib] = atoi(val[ip].c_str());
529+
this->kbeta[ib] = safe_convert<int>(val[ip], 0);
509530
}
510531
else if (name[ip] == "cutoff_radius")
511532
{
512-
rcut[ib] = atof(val[ip].c_str());
533+
rcut[ib] = safe_convert<float>(val[ip], 0.0);
513534
}
514535
else if (name[ip] == "ultrasoft_cutoff_radius")
515536
{
516-
rcutus[ib] = atof(val[ip].c_str());
537+
rcutus[ib] = safe_convert<float>(val[ip], 0.0);
517538
}
518539
else
519540
{
@@ -572,11 +593,11 @@ void Pseudopot_upf::read_pseudo_upf201_nonlocal(std::ifstream& ifs, Atom_pseudo&
572593
}
573594
else if (name[ip] == "nqf")
574595
{
575-
nqf = atoi(val[ip].c_str());
596+
nqf = safe_convert<int>(val[ip], 0);
576597
}
577598
else if (name[ip] == "nqlc")
578599
{
579-
pp.nqlc = atoi(val[ip].c_str());
600+
pp.nqlc = safe_convert<int>(val[ip], 0);
580601
}
581602
else
582603
{
@@ -752,31 +773,31 @@ void Pseudopot_upf::read_pseudo_upf201_pswfc(std::ifstream& ifs, Atom_pseudo& pp
752773
}
753774
else if (name[ip] == "l")
754775
{
755-
pp.lchi[iw] = atoi(val[ip].c_str());
776+
pp.lchi[iw] = safe_convert<int>(val[ip], 0);
756777
if (nchi[iw] == -1)
757778
{
758779
nchi[iw] = pp.lchi[iw] - 1;
759780
}
760781
}
761782
else if (name[ip] == "occupation")
762783
{
763-
pp.oc[iw] = atof(val[ip].c_str());
784+
pp.oc[iw] = safe_convert<float>(val[ip], 0.0);
764785
}
765786
else if (name[ip] == "n")
766787
{
767-
nchi[iw] = atoi(val[ip].c_str());
788+
nchi[iw] = safe_convert<int>(val[ip], 0);
768789
}
769790
else if (name[ip] == "pseudo_energy")
770791
{
771-
epseu[iw] = atof(val[ip].c_str());
792+
epseu[iw] = safe_convert<float>(val[ip], 0.0);
772793
}
773794
else if (name[ip] == "cutoff_radius")
774795
{
775-
rcut_chi[iw] = atof(val[ip].c_str());
796+
rcut_chi[iw] = safe_convert<float>(val[ip], 0.0);
776797
}
777798
else if (name[ip] == "ultrasoft_cutoff_radius")
778799
{
779-
rcutus_chi[iw] = atof(val[ip].c_str());
800+
rcutus_chi[iw] = safe_convert<float>(val[ip], 0.0);
780801
}
781802
else
782803
{
@@ -870,19 +891,19 @@ void Pseudopot_upf::read_pseudo_upf201_so(std::ifstream& ifs, Atom_pseudo& pp)
870891
}
871892
else if (name[ip] == "nn")
872893
{
873-
pp.nn[nw] = atoi(val[ip].c_str());
894+
pp.nn[nw] = safe_convert<int>(val[ip], 0);
874895
}
875896
else if (name[ip] == "lchi")
876897
{
877-
pp.lchi[nw] = atoi(val[ip].c_str());
898+
pp.lchi[nw] = safe_convert<int>(val[ip], 0);
878899
}
879900
else if (name[ip] == "jchi")
880901
{
881-
pp.jchi[nw] = atof(val[ip].c_str());
902+
pp.jchi[nw] = safe_convert<float>(val[ip], 0.0);
882903
}
883904
else if (name[ip] == "oc")
884905
{
885-
pp.oc[nw] = atof(val[ip].c_str());
906+
pp.oc[nw] = safe_convert<float>(val[ip], 0.0);
886907
}
887908
else
888909
{
@@ -907,11 +928,11 @@ void Pseudopot_upf::read_pseudo_upf201_so(std::ifstream& ifs, Atom_pseudo& pp)
907928
}
908929
else if (name[ip] == "lll")
909930
{
910-
pp.lll[nb] = atoi(val[ip].c_str());
931+
pp.lll[nb] = safe_convert<int>(val[ip], 0);
911932
}
912933
else if (name[ip] == "jjj")
913934
{
914-
pp.jjj[nb] = atof(val[ip].c_str());
935+
pp.jjj[nb] = safe_convert<float>(val[ip], 0.0);
915936
}
916937
else
917938
{

source/source_hsolver/module_genelpa/utils.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,31 @@ void initBlacsGrid(int loglevel,
8282
}
8383
#endif
8484

85+
// open helper function
86+
template<typename FileStreamType>
87+
void safe_open_file(FileStreamType& file, const std::string& filename,
88+
std::ios_base::openmode mode = std::ios_base::in,
89+
int myid = 0, bool is_parallel = true) {
90+
file.open(filename, mode);
91+
92+
if (!file.is_open()) {
93+
std::stringstream ss;
94+
ss << "ERROR: Cannot open file '" << filename << "'";
95+
if (errno != 0) {
96+
ss << " - " << strerror(errno);
97+
}
98+
99+
std::cerr << ss.str() << std::endl;
100+
101+
// handle MPI problem
102+
if (is_parallel) {
103+
MPI_Abort(MPI_COMM_WORLD, 1);
104+
} else {
105+
throw std::runtime_error(ss.str());
106+
}
107+
}
108+
}
109+
85110
// load matrix from the file
86111
void loadMatrix(const char FileName[], int nFull, double* a, int* desca, int blacs_ctxt)
87112
{
@@ -92,7 +117,7 @@ void loadMatrix(const char FileName[], int nFull, double* a, int* desca, int bla
92117
const int ROOT_PROC = 0;
93118
std::ifstream matrixFile;
94119
if (myid == ROOT_PROC)
95-
matrixFile.open(FileName);
120+
safe_open_file(matrixFile, FileName, std::ios_base::in, myid, true);
96121

97122
double* b = nullptr; // buffer
98123
const int MAX_BUFFER_SIZE = 1e9; // max buffer size is 1GB
@@ -146,7 +171,7 @@ void saveLocalMatrix(const char filePrefix[], int narows, int nacols, double* a)
146171
#endif
147172

148173
sprintf(FileName, "%s_%3.3d.dat", filePrefix, myid);
149-
matrixFile.open(FileName);
174+
safe_open_file(matrixFile, FileName, std::ios_base::in, myid, true);
150175
matrixFile.flags(std::ios_base::scientific);
151176
matrixFile.precision(17);
152177
matrixFile.width(24);
@@ -173,7 +198,7 @@ void saveMatrix(const char FileName[], int nFull, double* a, int* desca, int bla
173198
std::ofstream matrixFile;
174199
if (myid == ROOT_PROC) // setup saved matrix format
175200
{
176-
matrixFile.open(FileName);
201+
safe_open_file(matrixFile, FileName, std::ios_base::in, myid, true);
177202
matrixFile.flags(std::ios_base::scientific);
178203
matrixFile.precision(17);
179204
matrixFile.width(24);
@@ -229,7 +254,7 @@ void loadMatrix(const char FileName[], int nFull, std::complex<double>* a, int*
229254
const int ROOT_PROC = 0;
230255
std::ifstream matrixFile;
231256
if (myid == ROOT_PROC)
232-
matrixFile.open(FileName);
257+
safe_open_file(matrixFile, FileName, std::ios_base::in, myid, true);
233258

234259
std::complex<double>* b; // buffer
235260
const int MAX_BUFFER_SIZE = 1e9; // max buffer size is 1GB
@@ -284,7 +309,7 @@ void saveLocalMatrix(const char filePrefix[], int narows, int nacols, std::compl
284309
#endif
285310

286311
sprintf(FileName, "%s_%3.3d.dat", filePrefix, myid);
287-
matrixFile.open(FileName);
312+
safe_open_file(matrixFile, FileName, std::ios_base::in, myid, true);
288313
matrixFile.flags(std::ios_base::scientific);
289314
matrixFile.precision(17);
290315
matrixFile.width(24);
@@ -311,7 +336,7 @@ void saveMatrix(const char FileName[], int nFull, std::complex<double>* a, int*
311336
std::ofstream matrixFile;
312337
if (myid == ROOT_PROC) // setup saved matrix format
313338
{
314-
matrixFile.open(FileName);
339+
safe_open_file(matrixFile, FileName, std::ios_base::in, myid, true);
315340
matrixFile.flags(std::ios_base::scientific);
316341
matrixFile.precision(17);
317342
matrixFile.width(24);

0 commit comments

Comments
 (0)