diff --git a/source/module_base/formatter.h b/source/module_base/formatter.h index 614988f7c3..963964f0ef 100644 --- a/source/module_base/formatter.h +++ b/source/module_base/formatter.h @@ -145,6 +145,20 @@ class FmtCore [&delim](const std::string& acc, const std::string& s) { return acc + delim + s; }); } + static std::string upper(const std::string& in) + { + std::string dst = in; + std::transform(dst.begin(), dst.end(), dst.begin(), ::toupper); + return dst; + } + + static std::string lower(const std::string& in) + { + std::string dst = in; + std::transform(dst.begin(), dst.end(), dst.begin(), ::tolower); + return dst; + } + private: std::string fmt_; template diff --git a/source/module_io/read_input.cpp b/source/module_io/read_input.cpp index 21235eb461..50bf47f4d1 100644 --- a/source/module_io/read_input.cpp +++ b/source/module_io/read_input.cpp @@ -5,6 +5,10 @@ #include #include #include +#include +#include +#include +#include "module_base/formatter.h" #include "module_base/global_file.h" #include "module_base/global_function.h" #include "module_base/tool_quit.h" @@ -12,72 +16,41 @@ namespace ModuleIO { -void strtolower(char* sa, char* sb) +std::string longstring(const std::vector& words) { - char c; - int len = strlen(sa); - for (int i = 0; i < len; i++) - { - c = sa[i]; - sb[i] = tolower(c); - } - sb[len] = '\0'; + return FmtCore::join(" ", words); } -std::string longstring(const std::vector& str_values) +bool assume_as_boolean(const std::string& val) { - std::string output; - output = ""; - const size_t length = str_values.size(); - for (int i = 0; i < length; ++i) - { - output += str_values[i]; - if (i != length - 1) - { - output += " "; - } - } - return output; -} + const std::string val_ = FmtCore::lower(val); -bool convert_bool(std::string str) -{ - for (auto& i: str) - { - i = tolower(i); - } - if (str == "true") - { - return true; - } - else if (str == "false") - { - return false; - } - else if (str == "1") - { - return true; - } - else if (str == "0") - { - return false; - } - else if (str == "t") + const std::array t_ = {"true", "1", "t", "yes", "y", "on", ".true."}; + const std::array f_ = {"false", "0", "f", "no", "n", "off", ".false."}; + // This will work because std::array::size() is a constexpr function + // Ouch it is of C++17 standard... + // static_assert(t_.size() == f_.size(), "t_ and f_ must have the same lengths"); +#ifdef __DEBUG // C++11 can do this + assert(t_.size() == f_.size()); +#endif + + if (std::find(t_.begin(), t_.end(), val_) != t_.end()) { return true; } - else if (str == "f") + else if (std::find(f_.begin(), f_.end(), val_) != f_.end()) { return false; } else { - std::string warningstr = "Bad boolean parameter "; - warningstr.append(str); - warningstr.append(", please check the input parameters in file INPUT"); - ModuleBase::WARNING_QUIT("Input", warningstr); + std::string warnmsg = "Bad boolean parameter "; + warnmsg.append(val); + warnmsg.append(", please check the input parameters in file INPUT"); + ModuleBase::WARNING_QUIT("Input", warnmsg); } } + std::string to_dir(const std::string& str) { std::string str_dir = str; @@ -216,8 +189,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) ifs.clear(); ifs.seekg(0); - char word[80]; - char word1[80]; + std::string word, word1; int ierr = 0; // ifs >> std::setiosflags(ios::uppercase); @@ -226,7 +198,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) { ifs >> word; ifs.ignore(150, '\n'); - if (strcmp(word, "INPUT_PARAMETERS") == 0) + if (word == "INPUT_PARAMETERS") { ierr = 1; break; @@ -247,10 +219,8 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) while (ifs.good()) { ifs >> word1; - if (ifs.eof()) { - break; -} - strtolower(word1, word); + if (ifs.eof()) { break; } + word = FmtCore::lower(word1); auto it = std::find_if(input_lists.begin(), input_lists.end(), [&word](const std::pair& item) { return item.first == word; }); @@ -311,7 +281,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) Input_Item* resetvalue_item = &(input_item.second); if (resetvalue_item->reset_value != nullptr) { resetvalue_item->reset_value(*resetvalue_item, param); -} + } } this->set_globalv(param); @@ -327,7 +297,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) Input_Item* checkvalue_item = &(input_item.second); if (checkvalue_item->check_value != nullptr) { checkvalue_item->check_value(*checkvalue_item, param); -} + } } } @@ -505,12 +475,6 @@ void ReadInput::add_item(const Input_Item& item) } } -bool find_str(const std::vector& strings, const std::string& strToFind) -{ - auto it = std::find(strings.begin(), strings.end(), strToFind); - return it != strings.end(); -} - std::string nofound_str(std::vector init_chgs, const std::string& str) { std::string warningstr = "The parameter "; diff --git a/source/module_io/read_input.h b/source/module_io/read_input.h index 0bd8745ad7..a39d043fe7 100644 --- a/source/module_io/read_input.h +++ b/source/module_io/read_input.h @@ -139,14 +139,10 @@ class ReadInput std::vector> bcastfuncs; }; -// convert string to lower case -void strtolower(char* sa, char* sb); // convert string vector to a long string std::string longstring(const std::vector& str_values); // convert string to bool -bool convert_bool(std::string str); -// if find a string in a vector of strings -bool find_str(const std::vector& strings, const std::string& strToFind); +bool assume_as_boolean(const std::string& val); // convert to directory format std::string to_dir(const std::string& str); // return a warning string if the string is not found in the vector diff --git a/source/module_io/read_input_item_elec_stru.cpp b/source/module_io/read_input_item_elec_stru.cpp index 6b6b9944e7..cf7ec3cb89 100644 --- a/source/module_io/read_input_item_elec_stru.cpp +++ b/source/module_io/read_input_item_elec_stru.cpp @@ -76,7 +76,7 @@ void ReadInput::item_elec_stru() if (para.input.basis_type == "pw") { - if (!find_str(pw_solvers, ks_solver)) + if (std::find(pw_solvers.begin(), pw_solvers.end(), ks_solver) == pw_solvers.end()) { const std::string warningstr = "For PW basis: " + nofound_str(pw_solvers, "ks_solver"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); @@ -84,7 +84,7 @@ void ReadInput::item_elec_stru() } else if (para.input.basis_type == "lcao") { - if (!find_str(lcao_solvers, ks_solver)) + if (std::find(lcao_solvers.begin(), lcao_solvers.end(), ks_solver) == lcao_solvers.end()) { const std::string warningstr = "For LCAO basis: " + nofound_str(lcao_solvers, "ks_solver"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); @@ -163,7 +163,7 @@ void ReadInput::item_elec_stru() }; item.check_value = [](const Input_Item& item, const Parameter& para) { const std::vector basis_types = {"pw", "lcao_in_pw", "lcao"}; - if (!find_str(basis_types, para.input.basis_type)) + if (std::find(basis_types.begin(), basis_types.end(), para.input.basis_type) == basis_types.end()) { const std::string warningstr = nofound_str(basis_types, "basis_type"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); diff --git a/source/module_io/read_input_item_output.cpp b/source/module_io/read_input_item_output.cpp index d0001e805d..03742d24d0 100644 --- a/source/module_io/read_input_item_output.cpp +++ b/source/module_io/read_input_item_output.cpp @@ -11,7 +11,7 @@ void ReadInput::item_output() item.annotation = "output the structure files after each ion step"; item.reset_value = [](const Input_Item& item, Parameter& para) { const std::vector offlist = {"nscf", "get_S", "get_pchg", "get_wf"}; - if (find_str(offlist, para.input.calculation)) + if (std::find(offlist.begin(), offlist.end(), para.input.calculation) != offlist.end()) { para.input.out_stru = false; } @@ -96,21 +96,13 @@ void ReadInput::item_output() Input_Item item("out_band"); item.annotation = "output energy and band structure (with precision 8)"; item.read_value = [](const Input_Item& item, Parameter& para) { - size_t count = item.get_size(); - if (count == 1) - { - para.input.out_band[0] = std::stoi(item.str_values[0]); - para.input.out_band[1] = 8; - } - else if (count == 2) - { - para.input.out_band[0] = std::stoi(item.str_values[0]); - para.input.out_band[1] = std::stoi(item.str_values[1]); - } - else + const size_t count = item.get_size(); + if (count != 1 && count != 2) { ModuleBase::WARNING_QUIT("ReadInput", "out_band should have 1 or 2 values"); } + para.input.out_band[0] = assume_as_boolean(item.str_values[0]); + para.input.out_band[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8; }; item.reset_value = [](const Input_Item& item, Parameter& para) { if (para.input.calculation == "get_wf" || para.input.calculation == "get_pchg") @@ -239,21 +231,13 @@ void ReadInput::item_output() Input_Item item("out_mat_hs"); item.annotation = "output H and S matrix (with precision 8)"; item.read_value = [](const Input_Item& item, Parameter& para) { - size_t count = item.get_size(); - if (count == 1) - { - para.input.out_mat_hs[0] = std::stoi(item.str_values[0]); - para.input.out_mat_hs[1] = 8; - } - else if (count == 2) - { - para.input.out_mat_hs[0] = std::stoi(item.str_values[0]); - para.input.out_mat_hs[1] = std::stoi(item.str_values[1]); - } - else + const size_t count = item.get_size(); + if (count != 1 && count != 2) { ModuleBase::WARNING_QUIT("ReadInput", "out_mat_hs should have 1 or 2 values"); } + para.input.out_mat_hs[0] = assume_as_boolean(item.str_values[0]); + para.input.out_mat_hs[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8; }; item.reset_value = [](const Input_Item& item, Parameter& para) { if (para.input.qo_switch) @@ -268,21 +252,13 @@ void ReadInput::item_output() Input_Item item("out_mat_tk"); item.annotation = "output T(k)"; item.read_value = [](const Input_Item& item, Parameter& para) { - size_t count = item.get_size(); - if (count == 1) - { - para.input.out_mat_tk[0] = std::stoi(item.str_values[0]); - para.input.out_mat_tk[1] = 8; - } - else if (count == 2) - { - para.input.out_mat_tk[0] = std::stoi(item.str_values[0]); - para.input.out_mat_tk[1] = std::stoi(item.str_values[1]); - } - else + const size_t count = item.get_size(); + if (count != 1 && count != 2) { ModuleBase::WARNING_QUIT("ReadInput", "out_mat_tk should have 1 or 2 values"); } + para.input.out_mat_tk[0] = assume_as_boolean(item.str_values[0]); + para.input.out_mat_tk[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8; }; sync_intvec(input.out_mat_tk, 2, 0); this->add_item(item); diff --git a/source/module_io/read_input_item_relax.cpp b/source/module_io/read_input_item_relax.cpp index c9545576c2..f222b25ef9 100644 --- a/source/module_io/read_input_item_relax.cpp +++ b/source/module_io/read_input_item_relax.cpp @@ -13,7 +13,7 @@ void ReadInput::item_relax() read_sync_string(input.relax_method); item.check_value = [](const Input_Item& item, const Parameter& para) { const std::vector relax_methods = {"cg", "bfgs", "sd", "cg_bfgs"}; - if (!find_str(relax_methods, para.input.relax_method)) + if (std::find(relax_methods.begin(), relax_methods.end(), para.input.relax_method) == relax_methods.end()) { const std::string warningstr = nofound_str(relax_methods, "relax_method"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); @@ -52,7 +52,7 @@ void ReadInput::item_relax() const std::string& calculation = para.input.calculation; const std::vector singlelist = {"scf", "nscf", "get_S", "get_pchg", "get_wf", "test_memory", "test_neighbour", "gen_bessel"}; - if (find_str(singlelist, calculation)) + if (std::find(singlelist.begin(), singlelist.end(), calculation) != singlelist.end()) { para.input.relax_nmax = 1; } diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index 1ceae0b79c..7398dc2e6a 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -80,7 +80,7 @@ void ReadInput::item_system() "get_wf", "get_pchg", "gen_bessel"}; - if (!find_str(callist, calculation)) + if (std::find(callist.begin(), callist.end(), calculation) == callist.end()) { const std::string warningstr = nofound_str(callist, "calculation"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); @@ -111,7 +111,7 @@ void ReadInput::item_system() read_sync_string(input.esolver_type); item.check_value = [](const Input_Item& item, const Parameter& para) { const std::vector esolver_types = { "ksdft", "sdft", "ofdft", "tddft", "lj", "dp", "lr", "ks-lr" }; - if (!find_str(esolver_types, para.input.esolver_type)) + if (std::find(esolver_types.begin(), esolver_types.end(), para.input.esolver_type) == esolver_types.end()) { const std::string warningstr = nofound_str(esolver_types, "esolver_type"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); @@ -208,7 +208,7 @@ void ReadInput::item_system() item.reset_value = [](const Input_Item& item, Parameter& para) { std::vector use_force = {"cell-relax", "relax", "md"}; std::vector not_use_force = {"get_wf", "get_pchg", "nscf", "get_S"}; - if (find_str(use_force, para.input.calculation)) + if (std::find(use_force.begin(), use_force.end(), para.input.calculation) != use_force.end()) { if (!para.input.cal_force) { @@ -216,7 +216,7 @@ void ReadInput::item_system() } para.input.cal_force = true; } - else if (find_str(not_use_force, para.input.calculation)) + else if (std::find(not_use_force.begin(), not_use_force.end(), para.input.calculation) != not_use_force.end()) { if (para.input.cal_force) { @@ -538,7 +538,7 @@ void ReadInput::item_system() }; item.check_value = [](const Input_Item& item, const Parameter& para) { const std::vector init_chgs = {"atomic", "file", "wfc", "auto"}; - if (!find_str(init_chgs, para.input.init_chg)) + if (std::find(init_chgs.begin(), init_chgs.end(), para.input.init_chg) == init_chgs.end()) { const std::string warningstr = nofound_str(init_chgs, "init_chg"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); diff --git a/source/module_io/read_input_tool.h b/source/module_io/read_input_tool.h index 9df14d0e67..665d045ce8 100644 --- a/source/module_io/read_input_tool.h +++ b/source/module_io/read_input_tool.h @@ -8,7 +8,7 @@ #define strvalue item.str_values[0] #define intvalue std::stoi(item.str_values[0]) #define doublevalue std::stod(item.str_values[0]) -#define boolvalue convert_bool(item.str_values[0]) +#define boolvalue assume_as_boolean(item.str_values[0]) #ifdef __MPI #define add_double_bcast(PARAMETER) \