Skip to content

Commit 4087ffd

Browse files
authored
Feature: (minor) support various boolean expressions for outputting flags (#5489)
* Feature: (minor) support various boolean expressions for outputting flags * ouch, the constexpr of std::array::size() is of C++17 * remove libcomm and libri manually
1 parent c440291 commit 4087ffd

File tree

8 files changed

+69
-119
lines changed

8 files changed

+69
-119
lines changed

source/module_base/formatter.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,20 @@ class FmtCore
145145
[&delim](const std::string& acc, const std::string& s) { return acc + delim + s; });
146146
}
147147

148+
static std::string upper(const std::string& in)
149+
{
150+
std::string dst = in;
151+
std::transform(dst.begin(), dst.end(), dst.begin(), ::toupper);
152+
return dst;
153+
}
154+
155+
static std::string lower(const std::string& in)
156+
{
157+
std::string dst = in;
158+
std::transform(dst.begin(), dst.end(), dst.begin(), ::tolower);
159+
return dst;
160+
}
161+
148162
private:
149163
std::string fmt_;
150164
template<typename T>

source/module_io/read_input.cpp

Lines changed: 30 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,79 +5,52 @@
55
#include <fstream>
66
#include <iostream>
77
#include <sstream>
8+
#include <array>
9+
#include <vector>
10+
#include <cassert>
11+
#include "module_base/formatter.h"
812
#include "module_base/global_file.h"
913
#include "module_base/global_function.h"
1014
#include "module_base/tool_quit.h"
1115
#include "module_base/tool_title.h"
1216
namespace ModuleIO
1317
{
1418

15-
void strtolower(char* sa, char* sb)
19+
std::string longstring(const std::vector<std::string>& words)
1620
{
17-
char c;
18-
int len = strlen(sa);
19-
for (int i = 0; i < len; i++)
20-
{
21-
c = sa[i];
22-
sb[i] = tolower(c);
23-
}
24-
sb[len] = '\0';
21+
return FmtCore::join(" ", words);
2522
}
2623

27-
std::string longstring(const std::vector<std::string>& str_values)
24+
bool assume_as_boolean(const std::string& val)
2825
{
29-
std::string output;
30-
output = "";
31-
const size_t length = str_values.size();
32-
for (int i = 0; i < length; ++i)
33-
{
34-
output += str_values[i];
35-
if (i != length - 1)
36-
{
37-
output += " ";
38-
}
39-
}
40-
return output;
41-
}
26+
const std::string val_ = FmtCore::lower(val);
4227

43-
bool convert_bool(std::string str)
44-
{
45-
for (auto& i: str)
46-
{
47-
i = tolower(i);
48-
}
49-
if (str == "true")
50-
{
51-
return true;
52-
}
53-
else if (str == "false")
54-
{
55-
return false;
56-
}
57-
else if (str == "1")
58-
{
59-
return true;
60-
}
61-
else if (str == "0")
62-
{
63-
return false;
64-
}
65-
else if (str == "t")
28+
const std::array<std::string, 7> t_ = {"true", "1", "t", "yes", "y", "on", ".true."};
29+
const std::array<std::string, 7> f_ = {"false", "0", "f", "no", "n", "off", ".false."};
30+
// This will work because std::array<T, N>::size() is a constexpr function
31+
// Ouch it is of C++17 standard...
32+
// static_assert(t_.size() == f_.size(), "t_ and f_ must have the same lengths");
33+
#ifdef __DEBUG // C++11 can do this
34+
assert(t_.size() == f_.size());
35+
#endif
36+
37+
if (std::find(t_.begin(), t_.end(), val_) != t_.end())
6638
{
6739
return true;
6840
}
69-
else if (str == "f")
41+
else if (std::find(f_.begin(), f_.end(), val_) != f_.end())
7042
{
7143
return false;
7244
}
7345
else
7446
{
75-
std::string warningstr = "Bad boolean parameter ";
76-
warningstr.append(str);
77-
warningstr.append(", please check the input parameters in file INPUT");
78-
ModuleBase::WARNING_QUIT("Input", warningstr);
47+
std::string warnmsg = "Bad boolean parameter ";
48+
warnmsg.append(val);
49+
warnmsg.append(", please check the input parameters in file INPUT");
50+
ModuleBase::WARNING_QUIT("Input", warnmsg);
7951
}
8052
}
53+
8154
std::string to_dir(const std::string& str)
8255
{
8356
std::string str_dir = str;
@@ -216,8 +189,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
216189
ifs.clear();
217190
ifs.seekg(0);
218191

219-
char word[80];
220-
char word1[80];
192+
std::string word, word1;
221193
int ierr = 0;
222194

223195
// ifs >> std::setiosflags(ios::uppercase);
@@ -226,7 +198,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
226198
{
227199
ifs >> word;
228200
ifs.ignore(150, '\n');
229-
if (strcmp(word, "INPUT_PARAMETERS") == 0)
201+
if (word == "INPUT_PARAMETERS")
230202
{
231203
ierr = 1;
232204
break;
@@ -247,10 +219,8 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
247219
while (ifs.good())
248220
{
249221
ifs >> word1;
250-
if (ifs.eof()) {
251-
break;
252-
}
253-
strtolower(word1, word);
222+
if (ifs.eof()) { break; }
223+
word = FmtCore::lower(word1);
254224
auto it = std::find_if(input_lists.begin(),
255225
input_lists.end(),
256226
[&word](const std::pair<std::string, Input_Item>& item) { return item.first == word; });
@@ -311,7 +281,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
311281
Input_Item* resetvalue_item = &(input_item.second);
312282
if (resetvalue_item->reset_value != nullptr) {
313283
resetvalue_item->reset_value(*resetvalue_item, param);
314-
}
284+
}
315285
}
316286
this->set_globalv(param);
317287

@@ -327,7 +297,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
327297
Input_Item* checkvalue_item = &(input_item.second);
328298
if (checkvalue_item->check_value != nullptr) {
329299
checkvalue_item->check_value(*checkvalue_item, param);
330-
}
300+
}
331301
}
332302
}
333303

@@ -505,12 +475,6 @@ void ReadInput::add_item(const Input_Item& item)
505475
}
506476
}
507477

508-
bool find_str(const std::vector<std::string>& strings, const std::string& strToFind)
509-
{
510-
auto it = std::find(strings.begin(), strings.end(), strToFind);
511-
return it != strings.end();
512-
}
513-
514478
std::string nofound_str(std::vector<std::string> init_chgs, const std::string& str)
515479
{
516480
std::string warningstr = "The parameter ";

source/module_io/read_input.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,10 @@ class ReadInput
139139
std::vector<std::function<void(Parameter&)>> bcastfuncs;
140140
};
141141

142-
// convert string to lower case
143-
void strtolower(char* sa, char* sb);
144142
// convert string vector to a long string
145143
std::string longstring(const std::vector<std::string>& str_values);
146144
// convert string to bool
147-
bool convert_bool(std::string str);
148-
// if find a string in a vector of strings
149-
bool find_str(const std::vector<std::string>& strings, const std::string& strToFind);
145+
bool assume_as_boolean(const std::string& val);
150146
// convert to directory format
151147
std::string to_dir(const std::string& str);
152148
// return a warning string if the string is not found in the vector

source/module_io/read_input_item_elec_stru.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ void ReadInput::item_elec_stru()
7676

7777
if (para.input.basis_type == "pw")
7878
{
79-
if (!find_str(pw_solvers, ks_solver))
79+
if (std::find(pw_solvers.begin(), pw_solvers.end(), ks_solver) == pw_solvers.end())
8080
{
8181
const std::string warningstr = "For PW basis: " + nofound_str(pw_solvers, "ks_solver");
8282
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
8383
}
8484
}
8585
else if (para.input.basis_type == "lcao")
8686
{
87-
if (!find_str(lcao_solvers, ks_solver))
87+
if (std::find(lcao_solvers.begin(), lcao_solvers.end(), ks_solver) == lcao_solvers.end())
8888
{
8989
const std::string warningstr = "For LCAO basis: " + nofound_str(lcao_solvers, "ks_solver");
9090
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
@@ -163,7 +163,7 @@ void ReadInput::item_elec_stru()
163163
};
164164
item.check_value = [](const Input_Item& item, const Parameter& para) {
165165
const std::vector<std::string> basis_types = {"pw", "lcao_in_pw", "lcao"};
166-
if (!find_str(basis_types, para.input.basis_type))
166+
if (std::find(basis_types.begin(), basis_types.end(), para.input.basis_type) == basis_types.end())
167167
{
168168
const std::string warningstr = nofound_str(basis_types, "basis_type");
169169
ModuleBase::WARNING_QUIT("ReadInput", warningstr);

source/module_io/read_input_item_output.cpp

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ void ReadInput::item_output()
1111
item.annotation = "output the structure files after each ion step";
1212
item.reset_value = [](const Input_Item& item, Parameter& para) {
1313
const std::vector<std::string> offlist = {"nscf", "get_S", "get_pchg", "get_wf"};
14-
if (find_str(offlist, para.input.calculation))
14+
if (std::find(offlist.begin(), offlist.end(), para.input.calculation) != offlist.end())
1515
{
1616
para.input.out_stru = false;
1717
}
@@ -96,21 +96,13 @@ void ReadInput::item_output()
9696
Input_Item item("out_band");
9797
item.annotation = "output energy and band structure (with precision 8)";
9898
item.read_value = [](const Input_Item& item, Parameter& para) {
99-
size_t count = item.get_size();
100-
if (count == 1)
101-
{
102-
para.input.out_band[0] = std::stoi(item.str_values[0]);
103-
para.input.out_band[1] = 8;
104-
}
105-
else if (count == 2)
106-
{
107-
para.input.out_band[0] = std::stoi(item.str_values[0]);
108-
para.input.out_band[1] = std::stoi(item.str_values[1]);
109-
}
110-
else
99+
const size_t count = item.get_size();
100+
if (count != 1 && count != 2)
111101
{
112102
ModuleBase::WARNING_QUIT("ReadInput", "out_band should have 1 or 2 values");
113103
}
104+
para.input.out_band[0] = assume_as_boolean(item.str_values[0]);
105+
para.input.out_band[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8;
114106
};
115107
item.reset_value = [](const Input_Item& item, Parameter& para) {
116108
if (para.input.calculation == "get_wf" || para.input.calculation == "get_pchg")
@@ -239,21 +231,13 @@ void ReadInput::item_output()
239231
Input_Item item("out_mat_hs");
240232
item.annotation = "output H and S matrix (with precision 8)";
241233
item.read_value = [](const Input_Item& item, Parameter& para) {
242-
size_t count = item.get_size();
243-
if (count == 1)
244-
{
245-
para.input.out_mat_hs[0] = std::stoi(item.str_values[0]);
246-
para.input.out_mat_hs[1] = 8;
247-
}
248-
else if (count == 2)
249-
{
250-
para.input.out_mat_hs[0] = std::stoi(item.str_values[0]);
251-
para.input.out_mat_hs[1] = std::stoi(item.str_values[1]);
252-
}
253-
else
234+
const size_t count = item.get_size();
235+
if (count != 1 && count != 2)
254236
{
255237
ModuleBase::WARNING_QUIT("ReadInput", "out_mat_hs should have 1 or 2 values");
256238
}
239+
para.input.out_mat_hs[0] = assume_as_boolean(item.str_values[0]);
240+
para.input.out_mat_hs[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8;
257241
};
258242
item.reset_value = [](const Input_Item& item, Parameter& para) {
259243
if (para.input.qo_switch)
@@ -268,21 +252,13 @@ void ReadInput::item_output()
268252
Input_Item item("out_mat_tk");
269253
item.annotation = "output T(k)";
270254
item.read_value = [](const Input_Item& item, Parameter& para) {
271-
size_t count = item.get_size();
272-
if (count == 1)
273-
{
274-
para.input.out_mat_tk[0] = std::stoi(item.str_values[0]);
275-
para.input.out_mat_tk[1] = 8;
276-
}
277-
else if (count == 2)
278-
{
279-
para.input.out_mat_tk[0] = std::stoi(item.str_values[0]);
280-
para.input.out_mat_tk[1] = std::stoi(item.str_values[1]);
281-
}
282-
else
255+
const size_t count = item.get_size();
256+
if (count != 1 && count != 2)
283257
{
284258
ModuleBase::WARNING_QUIT("ReadInput", "out_mat_tk should have 1 or 2 values");
285259
}
260+
para.input.out_mat_tk[0] = assume_as_boolean(item.str_values[0]);
261+
para.input.out_mat_tk[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8;
286262
};
287263
sync_intvec(input.out_mat_tk, 2, 0);
288264
this->add_item(item);

source/module_io/read_input_item_relax.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ void ReadInput::item_relax()
1313
read_sync_string(input.relax_method);
1414
item.check_value = [](const Input_Item& item, const Parameter& para) {
1515
const std::vector<std::string> relax_methods = {"cg", "bfgs", "sd", "cg_bfgs"};
16-
if (!find_str(relax_methods, para.input.relax_method))
16+
if (std::find(relax_methods.begin(), relax_methods.end(), para.input.relax_method) == relax_methods.end())
1717
{
1818
const std::string warningstr = nofound_str(relax_methods, "relax_method");
1919
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
@@ -52,7 +52,7 @@ void ReadInput::item_relax()
5252
const std::string& calculation = para.input.calculation;
5353
const std::vector<std::string> singlelist
5454
= {"scf", "nscf", "get_S", "get_pchg", "get_wf", "test_memory", "test_neighbour", "gen_bessel"};
55-
if (find_str(singlelist, calculation))
55+
if (std::find(singlelist.begin(), singlelist.end(), calculation) != singlelist.end())
5656
{
5757
para.input.relax_nmax = 1;
5858
}

source/module_io/read_input_item_system.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void ReadInput::item_system()
8080
"get_wf",
8181
"get_pchg",
8282
"gen_bessel"};
83-
if (!find_str(callist, calculation))
83+
if (std::find(callist.begin(), callist.end(), calculation) == callist.end())
8484
{
8585
const std::string warningstr = nofound_str(callist, "calculation");
8686
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
@@ -111,7 +111,7 @@ void ReadInput::item_system()
111111
read_sync_string(input.esolver_type);
112112
item.check_value = [](const Input_Item& item, const Parameter& para) {
113113
const std::vector<std::string> esolver_types = { "ksdft", "sdft", "ofdft", "tddft", "lj", "dp", "lr", "ks-lr" };
114-
if (!find_str(esolver_types, para.input.esolver_type))
114+
if (std::find(esolver_types.begin(), esolver_types.end(), para.input.esolver_type) == esolver_types.end())
115115
{
116116
const std::string warningstr = nofound_str(esolver_types, "esolver_type");
117117
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
@@ -208,15 +208,15 @@ void ReadInput::item_system()
208208
item.reset_value = [](const Input_Item& item, Parameter& para) {
209209
std::vector<std::string> use_force = {"cell-relax", "relax", "md"};
210210
std::vector<std::string> not_use_force = {"get_wf", "get_pchg", "nscf", "get_S"};
211-
if (find_str(use_force, para.input.calculation))
211+
if (std::find(use_force.begin(), use_force.end(), para.input.calculation) != use_force.end())
212212
{
213213
if (!para.input.cal_force)
214214
{
215215
ModuleBase::GlobalFunc::AUTO_SET("cal_force", "true");
216216
}
217217
para.input.cal_force = true;
218218
}
219-
else if (find_str(not_use_force, para.input.calculation))
219+
else if (std::find(not_use_force.begin(), not_use_force.end(), para.input.calculation) != not_use_force.end())
220220
{
221221
if (para.input.cal_force)
222222
{
@@ -538,7 +538,7 @@ void ReadInput::item_system()
538538
};
539539
item.check_value = [](const Input_Item& item, const Parameter& para) {
540540
const std::vector<std::string> init_chgs = {"atomic", "file", "wfc", "auto"};
541-
if (!find_str(init_chgs, para.input.init_chg))
541+
if (std::find(init_chgs.begin(), init_chgs.end(), para.input.init_chg) == init_chgs.end())
542542
{
543543
const std::string warningstr = nofound_str(init_chgs, "init_chg");
544544
ModuleBase::WARNING_QUIT("ReadInput", warningstr);

source/module_io/read_input_tool.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#define strvalue item.str_values[0]
99
#define intvalue std::stoi(item.str_values[0])
1010
#define doublevalue std::stod(item.str_values[0])
11-
#define boolvalue convert_bool(item.str_values[0])
11+
#define boolvalue assume_as_boolean(item.str_values[0])
1212

1313
#ifdef __MPI
1414
#define add_double_bcast(PARAMETER) \

0 commit comments

Comments
 (0)