Skip to content

Commit 3c8ac0a

Browse files
authored
Fix: Solve mpi bug for DeePKS. (#5886)
* Fix mpi bug for DeePKS and add some timer ticks. * clang-format change. * Fix bug in DeePKS test. * Update deepks_basic.cpp
1 parent 74284f9 commit 3c8ac0a

File tree

15 files changed

+179
-117
lines changed

15 files changed

+179
-117
lines changed

source/module_esolver/lcao_before_scf.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,13 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
227227

228228
if (PARAM.inp.deepks_out_unittest)
229229
{
230-
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha);
230+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force,
231+
ucell,
232+
orb_,
233+
this->gd,
234+
pv,
235+
this->ld.phialpha,
236+
GlobalV::MY_RANK);
231237
}
232238
}
233239
#endif

source/module_esolver/lcao_others.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,13 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
233233

234234
if (PARAM.inp.deepks_out_unittest)
235235
{
236-
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha);
236+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force,
237+
ucell,
238+
orb_,
239+
this->gd,
240+
pv,
241+
this->ld.phialpha,
242+
GlobalV::MY_RANK);
237243
}
238244
}
239245
#endif

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void LCAO_Deepks::init(const LCAO_Orbitals& orb,
5353
std::vector<int> na)
5454
{
5555
ModuleBase::TITLE("LCAO_Deepks", "init");
56+
ModuleBase::timer::tick("LCAO_Deepks", "init");
5657

5758
GlobalV::ofs_running << " Initialize the descriptor index for DeePKS (lcao line)" << std::endl;
5859

@@ -123,6 +124,7 @@ void LCAO_Deepks::init(const LCAO_Orbitals& orb,
123124

124125
this->pv = &pv_in;
125126

127+
ModuleBase::timer::tick("LCAO_Deepks", "init");
126128
return;
127129
}
128130

@@ -169,6 +171,7 @@ void LCAO_Deepks::init_index(const int ntype,
169171
void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
170172
{
171173
ModuleBase::TITLE("LCAO_Deepks", "allocate_V_delta");
174+
ModuleBase::timer::tick("LCAO_Deepks", "allocate_V_delta");
172175

173176
// initialize the H matrix H_V_delta
174177
if (PARAM.globalv.gamma_only_local)
@@ -205,6 +208,7 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
205208
ModuleBase::GlobalFunc::ZEROS(this->gedm[inl], pdm_size);
206209
}
207210

211+
ModuleBase::timer::tick("LCAO_Deepks", "allocate_V_delta");
208212
return;
209213
}
210214

source/module_hamilt_lcao/module_deepks/LCAO_deepks_io.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,11 @@ void LCAO_deepks_io::save_matrix2npy(const std::string& file_name,
261261
template <typename T>
262262
void LCAO_deepks_io::save_tensor2npy(const std::string& file_name, const torch::Tensor& tensor, const int rank)
263263
{
264+
ModuleBase::TITLE("LCAO_deepks_io", "save_tensor2npy");
264265
if (rank != 0)
265266
{
266267
return;
267268
}
268-
ModuleBase::TITLE("LCAO_deepks_io", "save_tensor2npy");
269269
const int dim = tensor.dim();
270270
std::vector<long unsigned> shape(dim);
271271
for (int i = 0; i < dim; i++)

source/module_hamilt_lcao/module_deepks/deepks_basic.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#ifdef __DEEPKS
66
#include "deepks_basic.h"
7+
8+
#include "module_base/timer.h"
79
#include "module_parameter/parameter.h"
810

911
// d(Descriptor) / d(projected density matrix)
@@ -15,6 +17,7 @@ void DeePKS_domain::cal_gevdm(const int nat,
1517
std::vector<torch::Tensor>& gevdm)
1618
{
1719
ModuleBase::TITLE("DeePKS_domain", "cal_gevdm");
20+
ModuleBase::timer::tick("DeePKS_domain", "cal_gevdm");
1821
// cal gevdm(d(EigenValue(D))/dD)
1922
int nlmax = inlmax / nat;
2023
for (int nl = 0; nl < nlmax; ++nl)
@@ -48,12 +51,14 @@ void DeePKS_domain::cal_gevdm(const int nat,
4851
gevdm.push_back(avmm);
4952
}
5053
assert(gevdm.size() == nlmax);
54+
ModuleBase::timer::tick("DeePKS_domain", "cal_gevdm");
5155
return;
5256
}
5357

5458
void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script::Module& model)
5559
{
5660
ModuleBase::TITLE("DeePKS_domain", "load_model");
61+
ModuleBase::timer::tick("DeePKS_domain", "load_model");
5762

5863
try
5964
{
@@ -62,8 +67,10 @@ void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script
6267
catch (const c10::Error& e)
6368
{
6469
std::cerr << "error loading the model" << std::endl;
70+
ModuleBase::timer::tick("DeePKS_domain", "load_model");
6571
return;
6672
}
73+
ModuleBase::timer::tick("DeePKS_domain", "load_model");
6774
return;
6875
}
6976

@@ -122,16 +129,17 @@ inline void generate_py_files(const int lmaxd, const int nmaxd, const std::strin
122129
}
123130

124131
void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
125-
const int lmaxd,
126-
const int nmaxd,
127-
const int inlmax,
128-
const int des_per_atom,
129-
const int* inl_l,
130-
const std::vector<torch::Tensor>& descriptor,
131-
double** gedm,
132-
double& E_delta)
132+
const int lmaxd,
133+
const int nmaxd,
134+
const int inlmax,
135+
const int des_per_atom,
136+
const int* inl_l,
137+
const std::vector<torch::Tensor>& descriptor,
138+
double** gedm,
139+
double& E_delta)
133140
{
134141
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm_equiv");
142+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm_equiv");
135143

136144
LCAO_deepks_io::save_npy_d(nat,
137145
des_per_atom,
@@ -157,28 +165,32 @@ void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
157165

158166
std::string cmd = "rm -f cal_edelta_gedm.py basis.yaml ec.npy gedm.npy";
159167
std::system(cmd.c_str());
168+
169+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm_equiv");
170+
return;
160171
}
161172

162173
// obtain from the machine learning model dE_delta/dDescriptor
163174
// E_delta is also calculated here
164175
void DeePKS_domain::cal_edelta_gedm(const int nat,
165-
const int lmaxd,
166-
const int nmaxd,
167-
const int inlmax,
168-
const int des_per_atom,
169-
const int* inl_l,
170-
const std::vector<torch::Tensor>& descriptor,
171-
const std::vector<torch::Tensor>& pdm,
172-
torch::jit::script::Module& model_deepks,
173-
double** gedm,
174-
double& E_delta)
176+
const int lmaxd,
177+
const int nmaxd,
178+
const int inlmax,
179+
const int des_per_atom,
180+
const int* inl_l,
181+
const std::vector<torch::Tensor>& descriptor,
182+
const std::vector<torch::Tensor>& pdm,
183+
torch::jit::script::Module& model_deepks,
184+
double** gedm,
185+
double& E_delta)
175186
{
176187
if (PARAM.inp.deepks_equiv)
177188
{
178189
DeePKS_domain::cal_edelta_gedm_equiv(nat, lmaxd, nmaxd, inlmax, des_per_atom, inl_l, descriptor, gedm, E_delta);
179190
return;
180191
}
181192
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm");
193+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm");
182194

183195
// forward
184196
std::vector<torch::jit::IValue> inputs;
@@ -213,6 +225,7 @@ void DeePKS_domain::cal_edelta_gedm(const int nat,
213225
}
214226
}
215227
}
228+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm");
216229
return;
217230
}
218231

source/module_hamilt_lcao/module_deepks/deepks_force.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
2525
ModuleBase::matrix& svnl_dalpha)
2626
{
2727
ModuleBase::TITLE("DeePKS_domain", "cal_f_delta");
28-
28+
ModuleBase::timer::tick("DeePKS_domain", "cal_f_delta");
2929
f_delta.zero_out();
3030

3131
const double Rcut_Alpha = orb.Alpha[0].getRcut();
3232
const int lmaxd = orb.get_lmax_d();
33-
const int nrow = pv.nrow;
3433

3534
for (int T0 = 0; T0 < ucell.ntype; T0++)
3635
{
@@ -146,6 +145,22 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
146145
}
147146
}
148147

148+
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
149+
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
150+
if (overlap_1 == nullptr || overlap_2 == nullptr)
151+
{
152+
continue;
153+
}
154+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
155+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
156+
for (int i = 0; i < 3; ++i)
157+
{
158+
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
159+
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
160+
}
161+
162+
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
163+
149164
const double* dm_current = dm_pair.get_pointer();
150165

151166
for (int iw1 = 0; iw1 < row_indexes.size(); ++iw1)
@@ -155,18 +170,6 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
155170
double nlm[3] = {0, 0, 0};
156171
double nlm_t[3] = {0, 0, 0}; // for stress
157172

158-
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
159-
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
160-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
161-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
162-
for (int i = 0; i < 3; ++i)
163-
{
164-
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
165-
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
166-
}
167-
168-
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
169-
170173
if (!PARAM.inp.deepks_equiv)
171174
{
172175
int ib = 0;
@@ -310,7 +313,7 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
310313
}
311314
}
312315
}
313-
316+
ModuleBase::timer::tick("DeePKS_domain", "cal_f_delta");
314317
return;
315318
}
316319

source/module_hamilt_lcao/module_deepks/deepks_fpre.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,23 +136,27 @@ void DeePKS_domain::cal_gdmx(const int lmaxd,
136136

137137
dm_current = dm_pair.get_pointer();
138138

139-
for (int iw1 = 0; iw1 < row_indexes.size(); ++iw1)
139+
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
140+
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
141+
if (overlap_1 == nullptr || overlap_2 == nullptr)
140142
{
141-
for (int iw2 = 0; iw2 < col_indexes.size(); ++iw2)
142-
{
143-
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
144-
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
145-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
146-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
143+
continue;
144+
}
145+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
146+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
147147

148-
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
148+
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
149149

150-
for (int i = 0; i < 3; ++i)
151-
{
152-
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
153-
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
154-
}
150+
for (int i = 0; i < 3; ++i)
151+
{
152+
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
153+
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
154+
}
155155

156+
for (int iw1 = 0; iw1 < row_indexes.size(); ++iw1)
157+
{
158+
for (int iw2 = 0; iw2 < col_indexes.size(); ++iw2)
159+
{
156160
int ib = 0;
157161
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
158162
{
@@ -266,7 +270,7 @@ void DeePKS_domain::cal_gvx(const int nat,
266270
torch::Tensor& gvx)
267271
{
268272
ModuleBase::TITLE("DeePKS_domain", "cal_gvx");
269-
273+
ModuleBase::timer::tick("DeePKS_domain", "cal_gvx");
270274
// gdmr : nat(derivative) * 3 * inl(projector) * nm * nm
271275
std::vector<torch::Tensor> gdmr;
272276
auto accessor = gdmx.accessor<double, 5>();
@@ -332,7 +336,7 @@ void DeePKS_domain::cal_gvx(const int nat,
332336
assert(gvx.size(2) == nat);
333337
assert(gvx.size(3) == des_per_atom);
334338
}
335-
339+
ModuleBase::timer::tick("DeePKS_domain", "cal_gvx");
336340
return;
337341
}
338342

source/module_hamilt_lcao/module_deepks/deepks_orbital.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ void DeePKS_domain::cal_o_delta(const std::vector<TH>& dm_hl,
1515
const int nks)
1616
{
1717
ModuleBase::TITLE("DeePKS_domain", "cal_o_delta");
18+
ModuleBase::timer::tick("DeePKS_domain", "cal_o_delta");
1819

1920
for (int ik = 0; ik < nks; ik++)
2021
{
@@ -64,6 +65,7 @@ void DeePKS_domain::cal_o_delta(const std::vector<TH>& dm_hl,
6465
o_delta(ik, 0) = o_delta_tmp.real();
6566
}
6667
}
68+
ModuleBase::timer::tick("DeePKS_domain", "cal_o_delta");
6769
return;
6870
}
6971

source/module_hamilt_lcao/module_deepks/deepks_orbpre.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,9 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
9494

9595
ModuleBase::Vector3<int> dR1(GridD.getBox(ad1).x, GridD.getBox(ad1).y, GridD.getBox(ad1).z);
9696

97-
if constexpr (std::is_same<TK, std::complex<double>>::value)
97+
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr)
9898
{
99-
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr)
100-
{
101-
continue;
102-
}
99+
continue;
103100
}
104101

105102
auto row_indexes = pv.get_indexes_row(ibt1);
@@ -150,12 +147,9 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
150147

151148
ModuleBase::Vector3<int> dR2(GridD.getBox(ad2).x, GridD.getBox(ad2).y, GridD.getBox(ad2).z);
152149

153-
if constexpr (std::is_same<TK, std::complex<double>>::value)
150+
if (phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
154151
{
155-
if (phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
156-
{
157-
continue;
158-
}
152+
continue;
159153
}
160154

161155
auto col_indexes = pv.get_indexes_col(ibt2);

0 commit comments

Comments
 (0)