Skip to content

Commit 29ba84c

Browse files
committed
update NEP_CPU
1 parent 8ea1821 commit 29ba84c

File tree

8 files changed

+265
-2
lines changed

8 files changed

+265
-2
lines changed

CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,15 @@ if(DEFINED DeePMD_DIR)
650650
endif()
651651
endif()
652652

653+
option(ENABLE_NEPCPU "Enable NEP calculations on CPU" OFF)
654+
if(ENABLE_NEPCPU)
655+
message(STATUS "NEP support enabled.")
656+
include_directories(/home/mosey/devs/NEP_CPU/src/)
657+
link_directories(/home/mosey/devs/NEP_CPU/src/)
658+
add_compile_definitions(__NEP)
659+
target_link_libraries(${ABACUS_BIN_NAME} /home/mosey/devs/NEP_CPU/src/libnep_cpu.so)
660+
endif()
661+
653662
if(DEFINED TensorFlow_DIR)
654663
find_package(TensorFlow REQUIRED)
655664
include_directories(${TensorFlow_DIR}/include)

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ OBJS_ESOLVER=esolver.o\
266266
esolver_sdft_pw.o\
267267
esolver_lj.o\
268268
esolver_dp.o\
269+
esolver_nep.o\
269270
esolver_of.o\
270271
esolver_of_tddft.o\
271272
esolver_of_tool.o\

source/source_esolver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ list(APPEND objects
77
esolver_sdft_pw.cpp
88
esolver_lj.cpp
99
esolver_dp.cpp
10+
esolver_nep.cpp
1011
esolver_of.cpp
1112
esolver_of_tddft.cpp
1213
esolver_of_interface.cpp

source/source_esolver/esolver.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ extern "C"
1818
}
1919
#endif
2020
#include "esolver_dp.h"
21+
#include "esolver_nep.h"
2122
#include "esolver_lj.h"
2223
#include "esolver_of.h"
2324
#include "esolver_of_tddft.h"
@@ -97,6 +98,10 @@ std::string determine_type()
9798
{
9899
esolver_type = "dp_pot";
99100
}
101+
else if (PARAM.inp.esolver_type == "nep")
102+
{
103+
esolver_type = "nep_pot";
104+
}
100105
else if (esolver_type == "none")
101106
{
102107
ModuleBase::WARNING_QUIT("ESolver", "No such esolver_type combined with basis_type");
@@ -338,6 +343,10 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
338343
{
339344
return new ESolver_DP(PARAM.mdp.pot_file);
340345
}
346+
else if (esolver_type == "nep_pot")
347+
{
348+
return new ESolver_NEP(PARAM.mdp.pot_file);
349+
}
341350
throw std::invalid_argument("esolver_type = " + std::string(esolver_type) + ". Wrong in " + std::string(__FILE__)
342351
+ " line " + std::to_string(__LINE__));
343352
}
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
#include "esolver_nep.h"
2+
#include "source_base/parallel_common.h"
3+
#include "source_base/timer.h"
4+
#include "source_io/output_log.h"
5+
#include "source_io/cif_io.h"
6+
#include <numeric>
7+
#include <unordered_map>
8+
9+
using namespace ModuleESolver;
10+
11+
ESolver_NEP::ESolver_NEP(const std::string& pot_file): nep(pot_file)
12+
{
13+
classname = "ESolver_NEP";
14+
nep_file = pot_file;
15+
}
16+
17+
void ESolver_NEP::before_all_runners(UnitCell& ucell, const Input_para& inp)
18+
{
19+
nep_potential = 0.0;
20+
nep_force.create(ucell.nat, 3);
21+
nep_virial.create(3, 3);
22+
_e.resize(ucell.nat);
23+
_f.resize(3 * ucell.nat);
24+
_v.resize(9 * ucell.nat);
25+
26+
ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU.cif",
27+
ucell,
28+
"# Generated by ABACUS ModuleIO::CifParser",
29+
"data_?");
30+
31+
atype.resize(ucell.nat);
32+
33+
#ifdef __NEP
34+
/// determine the type map from STRU to DP model
35+
type_map(ucell);
36+
#endif
37+
}
38+
39+
void ESolver_NEP::runner(UnitCell& ucell, const int istep)
40+
{
41+
ModuleBase::TITLE("ESolver_NEP", "runner");
42+
ModuleBase::timer::tick("ESolver_NEP", "runner");
43+
44+
// note that NEP are column major, thus a transpose is needed
45+
// cell
46+
std::vector<double> cell(9, 0.0);
47+
cell[0] = ucell.latvec.e11 * ucell.lat0_angstrom;
48+
cell[1] = ucell.latvec.e21 * ucell.lat0_angstrom;
49+
cell[2] = ucell.latvec.e31 * ucell.lat0_angstrom;
50+
cell[3] = ucell.latvec.e12 * ucell.lat0_angstrom;
51+
cell[4] = ucell.latvec.e22 * ucell.lat0_angstrom;
52+
cell[5] = ucell.latvec.e32 * ucell.lat0_angstrom;
53+
cell[6] = ucell.latvec.e13 * ucell.lat0_angstrom;
54+
cell[7] = ucell.latvec.e23 * ucell.lat0_angstrom;
55+
cell[8] = ucell.latvec.e33 * ucell.lat0_angstrom;
56+
57+
// coord
58+
std::vector<double> coord(3 * ucell.nat, 0.0);
59+
int iat = 0;
60+
const int nat = ucell.nat;
61+
for (int it = 0; it < ucell.ntype; ++it)
62+
{
63+
for (int ia = 0; ia < ucell.atoms[it].na; ++ia)
64+
{
65+
coord[iat] = ucell.atoms[it].tau[ia].x * ucell.lat0_angstrom;
66+
coord[iat + nat] = ucell.atoms[it].tau[ia].y * ucell.lat0_angstrom;
67+
coord[iat + 2 * nat] = ucell.atoms[it].tau[ia].z * ucell.lat0_angstrom;
68+
iat++;
69+
}
70+
}
71+
assert(ucell.nat == iat);
72+
73+
#ifdef __NEP
74+
nep_potential = 0.0;
75+
nep_force.zero_out();
76+
nep_virial.zero_out();
77+
78+
nep.compute(atype, cell, coord, _e, _f, _v);
79+
80+
// unit conversion
81+
const double fact_e = 1.0 / ModuleBase::Ry_to_eV;
82+
const double fact_f = 1.0 / (ModuleBase::Ry_to_eV * ModuleBase::ANGSTROM_AU);
83+
const double fact_v = 1.0 / (ucell.omega * ModuleBase::Ry_to_eV);
84+
85+
86+
// potential energy
87+
nep_potential = fact_e * std::accumulate(_e.begin(), _e.end(), 0.0) ;
88+
GlobalV::ofs_running << " #TOTAL ENERGY# " << std::setprecision(11) << nep_potential * ModuleBase::Ry_to_eV << " eV"
89+
<< std::endl;
90+
91+
// forces
92+
for (int i = 0; i < nat; ++i)
93+
{
94+
nep_force(i, 0) = _f[i] * fact_f;
95+
nep_force(i, 1) = _f[i + nat] * fact_f;
96+
nep_force(i, 2) = _f[i + 2 * nat] * fact_f;
97+
}
98+
99+
// get the total virial by summing over all atomic contributions
100+
std::vector<double> v_sum(9, 0.0);
101+
for (int j = 0; j < 9; ++j)
102+
{
103+
for (int i = 0; i < nat; ++i)
104+
{
105+
int index = j * nat + i;
106+
v_sum[j] += _v[index];
107+
}
108+
}
109+
110+
// transform to stress tensor
111+
for (int i = 0; i < 3; ++i)
112+
{
113+
for (int j = 0; j < 3; ++j)
114+
{
115+
nep_virial(i, j) = v_sum[3 * i + j] * fact_v;
116+
}
117+
}
118+
#else
119+
ModuleBase::WARNING_QUIT("ESolver_NEP", "Please recompile with -D__NEP");
120+
#endif
121+
ModuleBase::timer::tick("ESolver_NEP", "runner");
122+
}
123+
124+
double ESolver_NEP::cal_energy()
125+
{
126+
return nep_potential;
127+
}
128+
129+
void ESolver_NEP::cal_force(UnitCell& ucell, ModuleBase::matrix& force)
130+
{
131+
force = nep_force;
132+
ModuleIO::print_force(GlobalV::ofs_running, ucell, "TOTAL-FORCE (eV/Angstrom)", force, false);
133+
}
134+
135+
void ESolver_NEP::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress)
136+
{
137+
stress = nep_virial;
138+
ModuleIO::print_stress("TOTAL-STRESS", stress, true, false, GlobalV::ofs_running);
139+
140+
// external stress
141+
double unit_transform = ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI, 3) * 1.0e-8;
142+
double external_stress[3] = {PARAM.inp.press1, PARAM.inp.press2, PARAM.inp.press3};
143+
for (int i = 0; i < 3; i++)
144+
{
145+
stress(i, i) -= external_stress[i] / unit_transform;
146+
}
147+
}
148+
149+
void ESolver_NEP::after_all_runners(UnitCell& ucell)
150+
{
151+
GlobalV::ofs_running << "\n --------------------------------------------" << std::endl;
152+
GlobalV::ofs_running << std::setprecision(16);
153+
GlobalV::ofs_running << " !FINAL_ETOT_IS " << nep_potential * ModuleBase::Ry_to_eV << " eV" << std::endl;
154+
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
155+
}
156+
157+
#ifdef __NEP
158+
void ESolver_NEP::type_map(const UnitCell& ucell)
159+
{
160+
std::unordered_map<std::string, int> label;
161+
std::string temp;
162+
for (int i = 0; i < nep.element_list.size(); ++i)
163+
{
164+
label[nep.element_list[i]] = i;
165+
}
166+
167+
std::cout << "\n Element list of model file " << nep_file << " " << std::endl;
168+
std::cout << " ----------------------------------------------------------------";
169+
int count = 0;
170+
for (auto it = label.begin(); it != label.end(); ++it)
171+
{
172+
if (count % 5 == 0)
173+
{
174+
std::cout << std::endl;
175+
std::cout << " ";
176+
}
177+
count++;
178+
temp = it->first + ": " + std::to_string(it->second);
179+
std::cout << std::left << std::setw(10) << temp;
180+
}
181+
std::cout << "\n -----------------------------------------------------------------" << std::endl;
182+
183+
int iat = 0;
184+
for (int it = 0; it < ucell.ntype; ++it)
185+
{
186+
for (int ia = 0; ia < ucell.atoms[it].na; ++ia)
187+
{
188+
if (label.find(ucell.atoms[it].label) == label.end())
189+
{
190+
ModuleBase::WARNING_QUIT("ESolver_NEP",
191+
"The label " + ucell.atoms[it].label + " is not found in the type map.");
192+
}
193+
atype[iat] = label[ucell.atoms[it].label];
194+
iat++;
195+
}
196+
}
197+
assert(ucell.nat == iat);
198+
}
199+
#endif
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef ESOLVER_NEP_H
2+
#define ESOLVER_NEP_H
3+
4+
#include "esolver.h"
5+
#ifdef __NEP
6+
#include "nep.h"
7+
#endif
8+
#include <vector>
9+
#include <string>
10+
11+
namespace ModuleESolver
12+
{
13+
14+
class ESolver_NEP : public ESolver
15+
{
16+
public:
17+
ESolver_NEP(const std::string& pot_file);
18+
19+
void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
20+
void runner(UnitCell& ucell, const int istep) override;
21+
double cal_energy() override;
22+
void cal_force(UnitCell& ucell, ModuleBase::matrix& force) override;
23+
void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) override;
24+
void after_all_runners(UnitCell& ucell) override;
25+
26+
private:
27+
void type_map(const UnitCell& ucell);
28+
#ifdef __NEP
29+
NEP3 nep;
30+
#endif
31+
std::string nep_file;
32+
std::vector<int> atype = {};
33+
double nep_potential;
34+
ModuleBase::matrix nep_force;
35+
ModuleBase::matrix nep_virial;
36+
std::vector<double> _e;
37+
std::vector<double> _f;
38+
std::vector<double> _v;
39+
};
40+
41+
} // namespace ModuleESolver
42+
43+
#endif

source/source_io/read_input_item_system.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ void ReadInput::item_system()
111111
item.annotation = "the energy solver: ksdft, sdft, ofdft, tdofdft, tddft, lj, dp, ks-lr, lr";
112112
read_sync_string(input.esolver_type);
113113
item.check_value = [](const Input_Item& item, const Parameter& para) {
114-
const std::vector<std::string> esolver_types = { "ksdft", "sdft", "ofdft", "tdofdft", "tddft", "lj", "dp", "lr", "ks-lr" };
114+
const std::vector<std::string> esolver_types = { "ksdft", "sdft", "ofdft", "tdofdft", "tddft", "lj", "dp", "nep", "lr", "ks-lr" };
115115
if (std::find(esolver_types.begin(), esolver_types.end(), para.input.esolver_type) == esolver_types.end())
116116
{
117117
const std::string warningstr = nofound_str(esolver_types, "esolver_type");
118118
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
119119
}
120-
if (para.input.esolver_type == "dp")
120+
if (para.input.esolver_type == "dp" || para.input.esolver_type == "nep")
121121
{
122122
if (access(para.input.mdp.pot_file.c_str(), 0) == -1)
123123
{

toolchain/build_abacus_intel.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ cmake -B $BUILD_DIR -DCMAKE_INSTALL_PREFIX=$PREFIX \
4949
-DRapidJSON_DIR=$RAPIDJSON \
5050
-DLIBRI_DIR=$LIBRI \
5151
-DLIBCOMM_DIR=$LIBCOMM \
52+
-DENABLE_NEPCPU=ON
5253
# -DENABLE_MLALGO=1 \
5354
# -DTorch_DIR=$LIBTORCH \
5455
# -Dlibnpy_INCLUDE_DIR=$LIBNPY \

0 commit comments

Comments
 (0)