Skip to content

Commit b1eb885

Browse files
committed
Merge branch 'develop' of https://gitee.com/deepmodeling/abacus-develop into rdmft
2 parents ad706c3 + 5816785 commit b1eb885

30 files changed

+763
-180
lines changed

source/module_base/formatter.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,20 @@ class FmtCore
145145
[&delim](const std::string& acc, const std::string& s) { return acc + delim + s; });
146146
}
147147

148+
static std::string upper(const std::string& in)
149+
{
150+
std::string dst = in;
151+
std::transform(dst.begin(), dst.end(), dst.begin(), ::toupper);
152+
return dst;
153+
}
154+
155+
static std::string lower(const std::string& in)
156+
{
157+
std::string dst = in;
158+
std::transform(dst.begin(), dst.end(), dst.begin(), ::tolower);
159+
return dst;
160+
}
161+
148162
private:
149163
std::string fmt_;
150164
template<typename T>

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -646,19 +646,22 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
646646

647647
#ifdef __EXX
648648
// calculate exact-exchange
649-
if (GlobalC::exx_info.info_ri.real_number)
650-
{
651-
this->exd->exx_eachiterinit(istep,
652-
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
653-
this->kv,
654-
iter);
655-
}
656-
else
649+
if (PARAM.inp.calculation != "nscf")
657650
{
658-
this->exc->exx_eachiterinit(istep,
659-
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
660-
this->kv,
661-
iter);
651+
if (GlobalC::exx_info.info_ri.real_number)
652+
{
653+
this->exd->exx_eachiterinit(istep,
654+
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
655+
this->kv,
656+
iter);
657+
}
658+
else
659+
{
660+
this->exc->exx_eachiterinit(istep,
661+
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
662+
this->kv,
663+
iter);
664+
}
662665
}
663666
#endif
664667

@@ -740,13 +743,16 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density_single(int istep, int iter, double
740743

741744
// 5) what's the exd used for?
742745
#ifdef __EXX
743-
if (GlobalC::exx_info.info_ri.real_number)
746+
if (PARAM.inp.calculation != "nscf")
744747
{
745-
this->exd->exx_hamilt2density(*this->pelec, this->pv, iter);
746-
}
747-
else
748-
{
749-
this->exc->exx_hamilt2density(*this->pelec, this->pv, iter);
748+
if (GlobalC::exx_info.info_ri.real_number)
749+
{
750+
this->exd->exx_hamilt2density(*this->pelec, this->pv, iter);
751+
}
752+
else
753+
{
754+
this->exc->exx_hamilt2density(*this->pelec, this->pv, iter);
755+
}
750756
}
751757
#endif
752758

@@ -951,11 +957,29 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(const int istep, int& iter)
951957
if( GlobalC::exx_info.info_global.cal_exx && this->conv_esolver ) one_step_exx = true;
952958

953959
// 3) save exx matrix
954-
if (GlobalC::exx_info.info_global.cal_exx)
960+
if (PARAM.inp.calculation != "nscf")
955961
{
956-
GlobalC::exx_info.info_ri.real_number ?
957-
this->exd->exx_iter_finish(this->kv, GlobalC::ucell, *this->p_hamilt, *this->pelec, *this->p_chgmix, this->scf_ene_thr, iter, istep, this->conv_esolver) :
958-
this->exc->exx_iter_finish(this->kv, GlobalC::ucell, *this->p_hamilt, *this->pelec, *this->p_chgmix, this->scf_ene_thr, iter, istep, this->conv_esolver);
962+
if (GlobalC::exx_info.info_global.cal_exx)
963+
{
964+
GlobalC::exx_info.info_ri.real_number ? this->exd->exx_iter_finish(this->kv,
965+
GlobalC::ucell,
966+
*this->p_hamilt,
967+
*this->pelec,
968+
*this->p_chgmix,
969+
this->scf_ene_thr,
970+
iter,
971+
istep,
972+
this->conv_esolver)
973+
: this->exc->exx_iter_finish(this->kv,
974+
GlobalC::ucell,
975+
*this->p_hamilt,
976+
*this->pelec,
977+
*this->p_chgmix,
978+
this->scf_ene_thr,
979+
iter,
980+
istep,
981+
this->conv_esolver);
982+
}
959983
}
960984
#endif
961985

@@ -1087,17 +1111,20 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
10871111

10881112
#ifdef __EXX
10891113
// 5) write Hexx matrix for NSCF (see `out_chg` in docs/advanced/input_files/input-main.md)
1090-
if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.out_chg[0]
1091-
&& istep % PARAM.inp.out_interval == 0) // Peize Lin add if 2022.11.14
1114+
if (PARAM.inp.calculation != "nscf")
10921115
{
1093-
const std::string file_name_exx = PARAM.globalv.global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
1094-
if (GlobalC::exx_info.info_ri.real_number)
1116+
if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.out_chg[0]
1117+
&& istep % PARAM.inp.out_interval == 0) // Peize Lin add if 2022.11.14
10951118
{
1096-
ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exd->get_Hexxs());
1097-
}
1098-
else
1099-
{
1100-
ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exc->get_Hexxs());
1119+
const std::string file_name_exx = PARAM.globalv.global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
1120+
if (GlobalC::exx_info.info_ri.real_number)
1121+
{
1122+
ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exd->get_Hexxs());
1123+
}
1124+
else
1125+
{
1126+
ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exc->get_Hexxs());
1127+
}
11011128
}
11021129
}
11031130
#endif

source/module_esolver/lcao_before_scf.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,16 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(const int istep)
210210

211211
// Peize Lin add 2016-12-03
212212
#ifdef __EXX // set xc type before the first cal of xc in pelec->init_scf
213-
if (GlobalC::exx_info.info_ri.real_number)
213+
if (PARAM.inp.calculation != "nscf")
214214
{
215-
this->exd->exx_beforescf(istep, this->kv, *this->p_chgmix, GlobalC::ucell, orb_);
216-
}
217-
else
218-
{
219-
this->exc->exx_beforescf(istep, this->kv, *this->p_chgmix, GlobalC::ucell, orb_);
215+
if (GlobalC::exx_info.info_ri.real_number)
216+
{
217+
this->exd->exx_beforescf(istep, this->kv, *this->p_chgmix, GlobalC::ucell, orb_);
218+
}
219+
else
220+
{
221+
this->exc->exx_beforescf(istep, this->kv, *this->p_chgmix, GlobalC::ucell, orb_);
222+
}
220223
}
221224
#endif // __EXX
222225

source/module_io/read_input.cpp

Lines changed: 30 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,79 +5,52 @@
55
#include <fstream>
66
#include <iostream>
77
#include <sstream>
8+
#include <array>
9+
#include <vector>
10+
#include <cassert>
11+
#include "module_base/formatter.h"
812
#include "module_base/global_file.h"
913
#include "module_base/global_function.h"
1014
#include "module_base/tool_quit.h"
1115
#include "module_base/tool_title.h"
1216
namespace ModuleIO
1317
{
1418

15-
void strtolower(char* sa, char* sb)
19+
std::string longstring(const std::vector<std::string>& words)
1620
{
17-
char c;
18-
int len = strlen(sa);
19-
for (int i = 0; i < len; i++)
20-
{
21-
c = sa[i];
22-
sb[i] = tolower(c);
23-
}
24-
sb[len] = '\0';
21+
return FmtCore::join(" ", words);
2522
}
2623

27-
std::string longstring(const std::vector<std::string>& str_values)
24+
bool assume_as_boolean(const std::string& val)
2825
{
29-
std::string output;
30-
output = "";
31-
const size_t length = str_values.size();
32-
for (int i = 0; i < length; ++i)
33-
{
34-
output += str_values[i];
35-
if (i != length - 1)
36-
{
37-
output += " ";
38-
}
39-
}
40-
return output;
41-
}
26+
const std::string val_ = FmtCore::lower(val);
4227

43-
bool convert_bool(std::string str)
44-
{
45-
for (auto& i: str)
46-
{
47-
i = tolower(i);
48-
}
49-
if (str == "true")
50-
{
51-
return true;
52-
}
53-
else if (str == "false")
54-
{
55-
return false;
56-
}
57-
else if (str == "1")
58-
{
59-
return true;
60-
}
61-
else if (str == "0")
62-
{
63-
return false;
64-
}
65-
else if (str == "t")
28+
const std::array<std::string, 7> t_ = {"true", "1", "t", "yes", "y", "on", ".true."};
29+
const std::array<std::string, 7> f_ = {"false", "0", "f", "no", "n", "off", ".false."};
30+
// This will work because std::array<T, N>::size() is a constexpr function
31+
// Ouch it is of C++17 standard...
32+
// static_assert(t_.size() == f_.size(), "t_ and f_ must have the same lengths");
33+
#ifdef __DEBUG // C++11 can do this
34+
assert(t_.size() == f_.size());
35+
#endif
36+
37+
if (std::find(t_.begin(), t_.end(), val_) != t_.end())
6638
{
6739
return true;
6840
}
69-
else if (str == "f")
41+
else if (std::find(f_.begin(), f_.end(), val_) != f_.end())
7042
{
7143
return false;
7244
}
7345
else
7446
{
75-
std::string warningstr = "Bad boolean parameter ";
76-
warningstr.append(str);
77-
warningstr.append(", please check the input parameters in file INPUT");
78-
ModuleBase::WARNING_QUIT("Input", warningstr);
47+
std::string warnmsg = "Bad boolean parameter ";
48+
warnmsg.append(val);
49+
warnmsg.append(", please check the input parameters in file INPUT");
50+
ModuleBase::WARNING_QUIT("Input", warnmsg);
7951
}
8052
}
53+
8154
std::string to_dir(const std::string& str)
8255
{
8356
std::string str_dir = str;
@@ -216,8 +189,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
216189
ifs.clear();
217190
ifs.seekg(0);
218191

219-
char word[80];
220-
char word1[80];
192+
std::string word, word1;
221193
int ierr = 0;
222194

223195
// ifs >> std::setiosflags(ios::uppercase);
@@ -226,7 +198,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
226198
{
227199
ifs >> word;
228200
ifs.ignore(150, '\n');
229-
if (strcmp(word, "INPUT_PARAMETERS") == 0)
201+
if (word == "INPUT_PARAMETERS")
230202
{
231203
ierr = 1;
232204
break;
@@ -247,10 +219,8 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
247219
while (ifs.good())
248220
{
249221
ifs >> word1;
250-
if (ifs.eof()) {
251-
break;
252-
}
253-
strtolower(word1, word);
222+
if (ifs.eof()) { break; }
223+
word = FmtCore::lower(word1);
254224
auto it = std::find_if(input_lists.begin(),
255225
input_lists.end(),
256226
[&word](const std::pair<std::string, Input_Item>& item) { return item.first == word; });
@@ -311,7 +281,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
311281
Input_Item* resetvalue_item = &(input_item.second);
312282
if (resetvalue_item->reset_value != nullptr) {
313283
resetvalue_item->reset_value(*resetvalue_item, param);
314-
}
284+
}
315285
}
316286
this->set_globalv(param);
317287

@@ -327,7 +297,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
327297
Input_Item* checkvalue_item = &(input_item.second);
328298
if (checkvalue_item->check_value != nullptr) {
329299
checkvalue_item->check_value(*checkvalue_item, param);
330-
}
300+
}
331301
}
332302
}
333303

@@ -505,12 +475,6 @@ void ReadInput::add_item(const Input_Item& item)
505475
}
506476
}
507477

508-
bool find_str(const std::vector<std::string>& strings, const std::string& strToFind)
509-
{
510-
auto it = std::find(strings.begin(), strings.end(), strToFind);
511-
return it != strings.end();
512-
}
513-
514478
std::string nofound_str(std::vector<std::string> init_chgs, const std::string& str)
515479
{
516480
std::string warningstr = "The parameter ";

source/module_io/read_input.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,10 @@ class ReadInput
139139
std::vector<std::function<void(Parameter&)>> bcastfuncs;
140140
};
141141

142-
// convert string to lower case
143-
void strtolower(char* sa, char* sb);
144142
// convert string vector to a long string
145143
std::string longstring(const std::vector<std::string>& str_values);
146144
// convert string to bool
147-
bool convert_bool(std::string str);
148-
// if find a string in a vector of strings
149-
bool find_str(const std::vector<std::string>& strings, const std::string& strToFind);
145+
bool assume_as_boolean(const std::string& val);
150146
// convert to directory format
151147
std::string to_dir(const std::string& str);
152148
// return a warning string if the string is not found in the vector

source/module_io/read_input_item_elec_stru.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ void ReadInput::item_elec_stru()
7676

7777
if (para.input.basis_type == "pw")
7878
{
79-
if (!find_str(pw_solvers, ks_solver))
79+
if (std::find(pw_solvers.begin(), pw_solvers.end(), ks_solver) == pw_solvers.end())
8080
{
8181
const std::string warningstr = "For PW basis: " + nofound_str(pw_solvers, "ks_solver");
8282
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
8383
}
8484
}
8585
else if (para.input.basis_type == "lcao")
8686
{
87-
if (!find_str(lcao_solvers, ks_solver))
87+
if (std::find(lcao_solvers.begin(), lcao_solvers.end(), ks_solver) == lcao_solvers.end())
8888
{
8989
const std::string warningstr = "For LCAO basis: " + nofound_str(lcao_solvers, "ks_solver");
9090
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
@@ -163,7 +163,7 @@ void ReadInput::item_elec_stru()
163163
};
164164
item.check_value = [](const Input_Item& item, const Parameter& para) {
165165
const std::vector<std::string> basis_types = {"pw", "lcao_in_pw", "lcao"};
166-
if (!find_str(basis_types, para.input.basis_type))
166+
if (std::find(basis_types.begin(), basis_types.end(), para.input.basis_type) == basis_types.end())
167167
{
168168
const std::string warningstr = nofound_str(basis_types, "basis_type");
169169
ModuleBase::WARNING_QUIT("ReadInput", warningstr);

0 commit comments

Comments
 (0)