Skip to content

Commit 2825851

Browse files
authored
Feature: support the default as the value of dft_functional when initialize vdw (#5949)
* Feature: support the `default` as the value of `dft_functional` when initialize vdw * Refactor a littble bit
1 parent b3bc912 commit 2825851

File tree

6 files changed

+107
-72
lines changed

6 files changed

+107
-72
lines changed

source/module_hamilt_general/module_vdw/test/vdw_test.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ class vdwd3Test: public testing::Test
377377
TEST_F(vdwd3Test, D30Default)
378378
{
379379
vdw::Vdwd3 vdwd3_test(ucell);
380-
vdwd3_test.parameter().initial_parameters(input);
380+
vdwd3_test.parameter().initial_parameters("pbe", input);
381381

382382
EXPECT_EQ(vdwd3_test.parameter().s6(), 1.0);
383383
EXPECT_EQ(vdwd3_test.parameter().s18(), 0.7875);
@@ -396,7 +396,8 @@ TEST_F(vdwd3Test, D30UnitA)
396396
input.vdw_cn_thr_unit = "A";
397397
vdw::Vdwd3 vdwd3_test(ucell);
398398

399-
vdwd3_test.parameter().initial_parameters(input);
399+
const std::string xc = "pbe";
400+
vdwd3_test.parameter().initial_parameters(xc, input);
400401

401402
EXPECT_EQ(vdwd3_test.parameter().rthr2(), std::pow(95/ModuleBase::BOHR_TO_A, 2));
402403
EXPECT_EQ(vdwd3_test.parameter().cn_thr2(), std::pow(40/ModuleBase::BOHR_TO_A, 2));
@@ -407,7 +408,8 @@ TEST_F(vdwd3Test, D30Period)
407408
input.vdw_cutoff_type = "period";
408409
vdw::Vdwd3 vdwd3_test(ucell);
409410

410-
vdwd3_test.parameter().initial_parameters(input);
411+
const std::string xc = "pbe";
412+
vdwd3_test.parameter().initial_parameters(xc, input);
411413
vdwd3_test.init();
412414
std::vector<int> rep_vdw_ref = {input.vdw_cutoff_period.x, input.vdw_cutoff_period.y, input.vdw_cutoff_period.z};
413415

source/module_hamilt_general/module_vdw/vdw.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,41 @@
1+
#include <algorithm>
2+
#include <cassert>
13

24
#include "vdw.h"
35
#include "vdwd2.h"
46
#include "vdwd3.h"
7+
#include "module_base/tool_quit.h"
8+
9+
std::string parse_xcname(const std::string &xc_input,
10+
const std::vector<std::string> &xc_psp)
11+
{
12+
if (xc_input != "default")
13+
{
14+
return xc_input;
15+
}
16+
17+
if (xc_psp.size() <= 0)
18+
{
19+
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::parse_xcname",
20+
"XC name automatic inference failed: no pseudopotential files are found");
21+
}
22+
std::vector<std::string> xc_psp_uniq = xc_psp;
23+
std::sort(xc_psp_uniq.begin(), xc_psp_uniq.end());
24+
auto last = std::unique(xc_psp_uniq.begin(), xc_psp_uniq.end());
25+
xc_psp_uniq.erase(last, xc_psp_uniq.end());
26+
27+
if (xc_psp_uniq.size() > 1)
28+
{
29+
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::parse_xcname",
30+
"XC name automatic inference failed: inconsistency in XC names is found"
31+
" in the pseudopotential files");
32+
}
33+
const std::string xc = xc_psp_uniq[0];
34+
std::cout << " ***WARNING*** ModuleHamiltGeneral::ModuleVDW::parse_xcname: "
35+
<< "XC name is automatically inferred from pseudopotential as `"
36+
<< xc << "`" << std::endl;
37+
return xc;
38+
}
539

640
namespace vdw
741
{
@@ -24,8 +58,13 @@ std::unique_ptr<Vdw> make_vdw(const UnitCell &ucell,
2458
}
2559
else if (input.vdw_method == "d3_0" || input.vdw_method == "d3_bj")
2660
{
61+
std::vector<std::string> xc_psp(ucell.ntype);
62+
for (int it = 0; it < ucell.ntype; it++)
63+
{
64+
xc_psp[it] = ucell.atoms[it].ncpp.xc_func;
65+
}
2766
std::unique_ptr<Vdwd3> vdw_ptr = make_unique<Vdwd3>(ucell);
28-
vdw_ptr->parameter().initial_parameters(input, plog);
67+
vdw_ptr->parameter().initial_parameters(parse_xcname(input.dft_functional, xc_psp), input, plog);
2968
return vdw_ptr;
3069
}
3170
else if (input.vdw_method != "none")

source/module_hamilt_general/module_vdw/vdw.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
namespace vdw
1212
{
1313

14-
template<typename T, typename... Args>
15-
std::unique_ptr<T> make_unique(Args &&... args) {
16-
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
14+
template<typename T, typename... Args>
15+
std::unique_ptr<T> make_unique(Args &&... args) {
16+
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
17+
1718
}
1819

1920
class Vdw

source/module_hamilt_general/module_vdw/vdwd3_autoset_xcparam.cpp

Lines changed: 52 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070

7171
// DFT-D3(BJ)
7272
const std::map<std::string, std::vector<double>> bj = {
73+
{"__default__", {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 14.0, 0.0}},
7374
{"bp", {1.0, 0.3946, 0.3946, 3.2822, 4.8516, 4.8516, 1.0, 14.0, 0.0}},
7475
{"blyp", {1.0, 0.4298, 0.4298, 2.6996, 4.2359, 4.2359, 1.0, 14.0, 0.0}},
7576
{"revpbe", {1.0, 0.5238, 0.5238, 2.355, 3.5016, 3.5016, 1.0, 14.0, 0.0}},
@@ -231,6 +232,7 @@ const std::map<std::string, std::vector<double>> bj = {
231232
};
232233
// DFT-D3(0)
233234
const std::map<std::string, std::vector<double>> zero = {
235+
{"__default__", {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 14.0, 0.0}},
234236
{"slaterdirac", {1.0, 0.999, 0.999, -1.957, 0.697, 0.697, 1.0, 14.0, 0.0}},
235237
{"bp", {1.0, 1.139, 1.139, 1.683, 1.0, 1.0, 1.0, 14.0, 0.0}},
236238
{"blyp", {1.0, 1.094, 1.094, 1.682, 1.0, 1.0, 1.0, 14.0, 0.0}},
@@ -318,6 +320,7 @@ const std::map<std::string, std::vector<double>> zero = {
318320
};
319321
// DFT-D3M(BJ): not implemented for beta
320322
const std::map<std::string, std::vector<double>> bjm = {
323+
{"__default__", {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 14.0, 0.0}},
321324
{"bp", {1.0, 0.82185, 0.82185, 3.140281, 2.728151, 2.728151, 1.0, 14.0, 0.0}},
322325
{"blyp", {1.0, 0.448486, 0.448486, 1.875007, 3.610679, 3.610679, 1.0, 14.0, 0.0}},
323326
{"b97_d", {1.0, 0.240184, 0.240184, 1.206988, 3.864426, 3.864426, 1.0, 14.0, 0.0}},
@@ -329,6 +332,7 @@ const std::map<std::string, std::vector<double>> bjm = {
329332
};
330333
// DFT-D3M(0): not implemented for beta
331334
const std::map<std::string, std::vector<double>> zerom = {
335+
{"__default__", {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 14.0, 0.0}},
332336
{"bp", {1.0, 1.23346, 1.23346, 1.945174, 1.0, 1.0, 1.0, 14.0, 0.0}},
333337
{"blyp", {1.0, 1.279637, 1.279637, 1.841686, 1.0, 1.0, 1.0, 14.0, 0.01437}},
334338
{"b97_d", {1.0, 1.151808, 1.151808, 1.020078, 1.0, 1.0, 1.0, 14.0, 0.035964}},
@@ -340,6 +344,8 @@ const std::map<std::string, std::vector<double>> zerom = {
340344
};
341345
// DFT-D3(OptimizedPower)
342346
const std::map<std::string, std::vector<double>> op = {
347+
// {'s6', 'rs6', 'a1', 's8', 'rs8', 'a2', 's9', 'alp', 'bet'}
348+
{"__default__", {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 14.0, 0.0}},
343349
{"blyp", {1.0, 0.425, 0.425, 1.31867, 3.5, 3.5, 1.0, 14.0, 2.0}},
344350
{"revpbe", {1.0, 0.6, 0.6, 1.44765, 2.5, 2.5, 1.0, 14.0, 0.0}},
345351
{"b97_d", {1.0, 0.6, 0.6, 1.46861, 2.5, 2.5, 1.0, 14.0, 0.0}},
@@ -356,7 +362,18 @@ const std::map<std::string, std::vector<double>> op = {
356362
{"ms2h", {1.0, 0.65, 0.65, 1.69464, 4.75, 4.75, 1.0, 14.0, 0.0}},
357363
};
358364

359-
365+
std::vector<double> _search_impl(const std::string& xc,
366+
const std::map<std::string, std::vector<double>>& dict)
367+
{
368+
if (dict.find(xc) != dict.end())
369+
{
370+
return dict.at(xc);
371+
}
372+
else
373+
{
374+
return std::vector<double>();
375+
}
376+
}
360377
// 's6', 'rs6', 'a1', 's8', 'rs8', 'a2', 's9', 'alp', 'bet'
361378
/**
362379
* @brief Get the dftd3 params object.
@@ -368,76 +385,49 @@ const std::map<std::string, std::vector<double>> op = {
368385
* @param param the dftd3 parameters, ALL_KEYS = {'s6', 'rs6', 'a1', 's8', 'rs8', 'a2', 's9', 'alp', 'bet'}
369386
*/
370387
void _search(const std::string& xc,
371-
const std::string& method,
372-
std::vector<double>& param)
388+
const std::string& method,
389+
std::vector<double>& param)
373390
{
374391
const std::string xc_lowercase = FmtCore::lower(xc);
375392
const std::vector<std::string> allowed_ = { "bj", "zero", "bjm", "zerom", "op" };
376-
assert(std::find(allowed_.begin(), allowed_.end(), method) != allowed_.end());
377-
if (method == "op")
378-
{
379-
if (op.find(xc_lowercase) != op.end())
380-
{
381-
param = op.at(xc_lowercase);
382-
}
383-
else
384-
{
385-
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
386-
"XC (`" + xc + "`)'s DFT-D3(OP) parameters not found");
387-
}
388-
}
389-
else if (method == "bjm")
393+
const int i = std::find(allowed_.begin(), allowed_.end(), method) - allowed_.begin();
394+
std::map<std::string, std::vector<double>> const * pdict = nullptr;
395+
switch (i)
390396
{
391-
if (bjm.find(xc_lowercase) != bjm.end())
392-
{
393-
param = bjm.at(xc_lowercase);
394-
}
395-
else
396-
{
397-
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
398-
"XC (`" + xc + "`)'s DFT-D3M(BJ) parameters not found");
399-
}
397+
case 0:
398+
pdict = &bj;
399+
break;
400+
case 1:
401+
pdict = &zero;
402+
break;
403+
case 2:
404+
pdict = &bjm;
405+
break;
406+
case 3:
407+
pdict = &zerom;
408+
break;
409+
case 4:
410+
pdict = &op;
411+
break;
412+
default:
413+
pdict = nullptr;
414+
break;
400415
}
401-
else if (method == "bj")
416+
if (pdict == nullptr)
402417
{
403-
if (bj.find(xc_lowercase) != bj.end())
404-
{
405-
param = bj.at(xc_lowercase);
406-
}
407-
else
408-
{
409-
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
410-
"XC (`" + xc + "`)'s DFT-D3(BJ) parameters not found");
411-
}
412-
}
413-
else if (method == "zerom")
414-
{
415-
if (zerom.find(xc_lowercase) != zerom.end())
416-
{
417-
param = zerom.at(xc_lowercase);
418-
}
419-
else
420-
{
421-
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
422-
"XC (`" + xc + "`)'s DFT-D3M(0) parameters not found");
423-
}
424-
}
425-
else if (method == "zero")
426-
{
427-
if (zero.find(xc_lowercase) != zero.end())
428-
{
429-
param = zero.at(xc_lowercase);
430-
}
431-
else
432-
{
433-
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
434-
"XC (`" + xc + "`)'s DFT-D3(0) parameters not found");
435-
}
418+
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
419+
"Unknown DFT-D3 method: " + method);
436420
}
437-
else // should not reach here
421+
param = _search_impl(xc_lowercase, *pdict);
422+
if (param.empty())
438423
{
439424
ModuleBase::WARNING_QUIT("ModuleHamiltGeneral::ModuleVDW::DFTD3::_search",
440-
"Unknown DFT-D3 method: " + method);
425+
"XC (`" + xc + "`)'s DFT-D3(" + method + ") parameters not found");
426+
// is it meaningful to return a so-called default value?
427+
std::cout << " ***WARNING*** "
428+
<< "XC (`" << xc << "`)'s DFT-D3(" << method << ") parameters not found, "
429+
<< "using default values. Please use at your own risk!" << std::endl;
430+
param = _search_impl("__default__", *pdict);
441431
}
442432
}
443433

source/module_hamilt_general/module_vdw/vdwd3_parameters.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
namespace vdw
1111
{
1212

13-
void Vdwd3Parameters::initial_parameters(const Input_para &input, std::ofstream* plog)
13+
void Vdwd3Parameters::initial_parameters(const std::string& xc,
14+
const Input_para& input,
15+
std::ofstream* plog)
1416
{
1517
// initialize the dftd3 parameters
1618
mxc_.resize(max_elem_, 1);
@@ -23,7 +25,7 @@ void Vdwd3Parameters::initial_parameters(const Input_para &input, std::ofstream*
2325
5,
2426
std::vector<std::vector<double>>(max_elem_, std::vector<double>(max_elem_, 0.0)))));
2527

26-
_vdwd3_autoset_xcparam(input.dft_functional, input.vdw_method,
28+
_vdwd3_autoset_xcparam(xc, input.vdw_method,
2729
input.vdw_s6, input.vdw_s8, input.vdw_a1, input.vdw_a2,
2830
s6_, s18_, rs6_, rs18_, /* rs6: a1, rs18: a2 */
2931
plog);

source/module_hamilt_general/module_vdw/vdwd3_parameters.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class Vdwd3Parameters : public VdwParameters
2727
* @param input Parameter instance
2828
* @param plog optional, for logging the parameter setting process
2929
*/
30-
void initial_parameters(const Input_para &input,
30+
void initial_parameters(const std::string& xc,
31+
const Input_para& input,
3132
std::ofstream* plog = nullptr); // for logging the parameter autoset
3233

3334
inline const std::string &version() const { return version_; }

0 commit comments

Comments
 (0)