Skip to content

Commit 188c3f1

Browse files
committed
Fix mpi bug for DeePKS and add some timer ticks.
1 parent 74284f9 commit 188c3f1

File tree

14 files changed

+128
-88
lines changed

14 files changed

+128
-88
lines changed

source/module_esolver/lcao_before_scf.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
227227

228228
if (PARAM.inp.deepks_out_unittest)
229229
{
230-
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha);
230+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha, GlobalV::MY_RANK);
231231
}
232232
}
233233
#endif

source/module_esolver/lcao_others.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
233233

234234
if (PARAM.inp.deepks_out_unittest)
235235
{
236-
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha);
236+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha, GlobalV::MY_RANK);
237237
}
238238
}
239239
#endif

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void LCAO_Deepks::init(const LCAO_Orbitals& orb,
5353
std::vector<int> na)
5454
{
5555
ModuleBase::TITLE("LCAO_Deepks", "init");
56+
ModuleBase::timer::tick("LCAO_Deepks", "init");
5657

5758
GlobalV::ofs_running << " Initialize the descriptor index for DeePKS (lcao line)" << std::endl;
5859

@@ -123,6 +124,7 @@ void LCAO_Deepks::init(const LCAO_Orbitals& orb,
123124

124125
this->pv = &pv_in;
125126

127+
ModuleBase::timer::tick("LCAO_Deepks", "init");
126128
return;
127129
}
128130

@@ -169,6 +171,7 @@ void LCAO_Deepks::init_index(const int ntype,
169171
void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
170172
{
171173
ModuleBase::TITLE("LCAO_Deepks", "allocate_V_delta");
174+
ModuleBase::timer::tick("LCAO_Deepks", "allocate_V_delta");
172175

173176
// initialize the H matrix H_V_delta
174177
if (PARAM.globalv.gamma_only_local)
@@ -205,6 +208,7 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
205208
ModuleBase::GlobalFunc::ZEROS(this->gedm[inl], pdm_size);
206209
}
207210

211+
ModuleBase::timer::tick("LCAO_Deepks", "allocate_V_delta");
208212
return;
209213
}
210214

source/module_hamilt_lcao/module_deepks/LCAO_deepks_io.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,11 @@ void LCAO_deepks_io::save_matrix2npy(const std::string& file_name,
261261
template <typename T>
262262
void LCAO_deepks_io::save_tensor2npy(const std::string& file_name, const torch::Tensor& tensor, const int rank)
263263
{
264+
ModuleBase::TITLE("LCAO_deepks_io", "save_tensor2npy");
264265
if (rank != 0)
265266
{
266267
return;
267268
}
268-
ModuleBase::TITLE("LCAO_deepks_io", "save_tensor2npy");
269269
const int dim = tensor.dim();
270270
std::vector<long unsigned> shape(dim);
271271
for (int i = 0; i < dim; i++)

source/module_hamilt_lcao/module_deepks/deepks_basic.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ void DeePKS_domain::cal_gevdm(const int nat,
1515
std::vector<torch::Tensor>& gevdm)
1616
{
1717
ModuleBase::TITLE("DeePKS_domain", "cal_gevdm");
18+
ModuleBase::timer::tick("DeePKS_domain", "cal_gevdm");
1819
// cal gevdm(d(EigenValue(D))/dD)
1920
int nlmax = inlmax / nat;
2021
for (int nl = 0; nl < nlmax; ++nl)
@@ -48,12 +49,14 @@ void DeePKS_domain::cal_gevdm(const int nat,
4849
gevdm.push_back(avmm);
4950
}
5051
assert(gevdm.size() == nlmax);
52+
ModuleBase::timer::tick("DeePKS_domain", "cal_gevdm");
5153
return;
5254
}
5355

5456
void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script::Module& model)
5557
{
5658
ModuleBase::TITLE("DeePKS_domain", "load_model");
59+
ModuleBase::timer::tick("DeePKS_domain", "load_model");
5760

5861
try
5962
{
@@ -64,6 +67,7 @@ void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script
6467
std::cerr << "error loading the model" << std::endl;
6568
return;
6669
}
70+
ModuleBase::timer::tick("DeePKS_domain", "load_model");
6771
return;
6872
}
6973

@@ -132,6 +136,7 @@ void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
132136
double& E_delta)
133137
{
134138
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm_equiv");
139+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm_equiv");
135140

136141
LCAO_deepks_io::save_npy_d(nat,
137142
des_per_atom,
@@ -157,6 +162,9 @@ void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
157162

158163
std::string cmd = "rm -f cal_edelta_gedm.py basis.yaml ec.npy gedm.npy";
159164
std::system(cmd.c_str());
165+
166+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm_equiv");
167+
return;
160168
}
161169

162170
// obtain from the machine learning model dE_delta/dDescriptor
@@ -179,6 +187,7 @@ void DeePKS_domain::cal_edelta_gedm(const int nat,
179187
return;
180188
}
181189
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm");
190+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm");
182191

183192
// forward
184193
std::vector<torch::jit::IValue> inputs;
@@ -213,6 +222,7 @@ void DeePKS_domain::cal_edelta_gedm(const int nat,
213222
}
214223
}
215224
}
225+
ModuleBase::timer::tick("DeePKS_domain", "cal_edelta_gedm");
216226
return;
217227
}
218228

source/module_hamilt_lcao/module_deepks/deepks_force.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
2525
ModuleBase::matrix& svnl_dalpha)
2626
{
2727
ModuleBase::TITLE("DeePKS_domain", "cal_f_delta");
28-
28+
ModuleBase::timer::tick("DeePKS_domain", "cal_f_delta");
2929
f_delta.zero_out();
3030

3131
const double Rcut_Alpha = orb.Alpha[0].getRcut();
3232
const int lmaxd = orb.get_lmax_d();
33-
const int nrow = pv.nrow;
3433

3534
for (int T0 = 0; T0 < ucell.ntype; T0++)
3635
{
@@ -146,6 +145,22 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
146145
}
147146
}
148147

148+
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
149+
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
150+
if (overlap_1 == nullptr || overlap_2 == nullptr)
151+
{
152+
continue;
153+
}
154+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
155+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
156+
for (int i = 0; i < 3; ++i)
157+
{
158+
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
159+
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
160+
}
161+
162+
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
163+
149164
const double* dm_current = dm_pair.get_pointer();
150165

151166
for (int iw1 = 0; iw1 < row_indexes.size(); ++iw1)
@@ -155,18 +170,6 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
155170
double nlm[3] = {0, 0, 0};
156171
double nlm_t[3] = {0, 0, 0}; // for stress
157172

158-
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
159-
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
160-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
161-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
162-
for (int i = 0; i < 3; ++i)
163-
{
164-
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
165-
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
166-
}
167-
168-
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
169-
170173
if (!PARAM.inp.deepks_equiv)
171174
{
172175
int ib = 0;
@@ -310,7 +313,7 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
310313
}
311314
}
312315
}
313-
316+
ModuleBase::timer::tick("DeePKS_domain", "cal_f_delta");
314317
return;
315318
}
316319

source/module_hamilt_lcao/module_deepks/deepks_fpre.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,23 +136,27 @@ void DeePKS_domain::cal_gdmx(const int lmaxd,
136136

137137
dm_current = dm_pair.get_pointer();
138138

139-
for (int iw1 = 0; iw1 < row_indexes.size(); ++iw1)
139+
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
140+
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
141+
if (overlap_1 == nullptr || overlap_2 == nullptr)
140142
{
141-
for (int iw2 = 0; iw2 < col_indexes.size(); ++iw2)
142-
{
143-
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
144-
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
145-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
146-
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
143+
continue;
144+
}
145+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_1(3);
146+
std::vector<hamilt::BaseMatrix<double>*> grad_overlap_2(3);
147147

148-
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
148+
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
149149

150-
for (int i = 0; i < 3; ++i)
151-
{
152-
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
153-
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
154-
}
150+
for (int i = 0; i < 3; ++i)
151+
{
152+
grad_overlap_1[i] = phialpha[i + 1]->find_matrix(iat, ibt1, dR1);
153+
grad_overlap_2[i] = phialpha[i + 1]->find_matrix(iat, ibt2, dR2);
154+
}
155155

156+
for (int iw1 = 0; iw1 < row_indexes.size(); ++iw1)
157+
{
158+
for (int iw2 = 0; iw2 < col_indexes.size(); ++iw2)
159+
{
156160
int ib = 0;
157161
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
158162
{
@@ -266,7 +270,7 @@ void DeePKS_domain::cal_gvx(const int nat,
266270
torch::Tensor& gvx)
267271
{
268272
ModuleBase::TITLE("DeePKS_domain", "cal_gvx");
269-
273+
ModuleBase::timer::tick("DeePKS_domain", "cal_gvx");
270274
// gdmr : nat(derivative) * 3 * inl(projector) * nm * nm
271275
std::vector<torch::Tensor> gdmr;
272276
auto accessor = gdmx.accessor<double, 5>();
@@ -332,7 +336,7 @@ void DeePKS_domain::cal_gvx(const int nat,
332336
assert(gvx.size(2) == nat);
333337
assert(gvx.size(3) == des_per_atom);
334338
}
335-
339+
ModuleBase::timer::tick("DeePKS_domain", "cal_gvx");
336340
return;
337341
}
338342

source/module_hamilt_lcao/module_deepks/deepks_orbital.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ void DeePKS_domain::cal_o_delta(const std::vector<TH>& dm_hl,
1515
const int nks)
1616
{
1717
ModuleBase::TITLE("DeePKS_domain", "cal_o_delta");
18+
ModuleBase::timer::tick("DeePKS_domain", "cal_o_delta");
1819

1920
for (int ik = 0; ik < nks; ik++)
2021
{
@@ -64,6 +65,7 @@ void DeePKS_domain::cal_o_delta(const std::vector<TH>& dm_hl,
6465
o_delta(ik, 0) = o_delta_tmp.real();
6566
}
6667
}
68+
ModuleBase::timer::tick("DeePKS_domain", "cal_o_delta");
6769
return;
6870
}
6971

source/module_hamilt_lcao/module_deepks/deepks_orbpre.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,9 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
9494

9595
ModuleBase::Vector3<int> dR1(GridD.getBox(ad1).x, GridD.getBox(ad1).y, GridD.getBox(ad1).z);
9696

97-
if constexpr (std::is_same<TK, std::complex<double>>::value)
97+
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr)
9898
{
99-
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr)
100-
{
101-
continue;
102-
}
99+
continue;
103100
}
104101

105102
auto row_indexes = pv.get_indexes_row(ibt1);
@@ -150,12 +147,9 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
150147

151148
ModuleBase::Vector3<int> dR2(GridD.getBox(ad2).x, GridD.getBox(ad2).y, GridD.getBox(ad2).z);
152149

153-
if constexpr (std::is_same<TK, std::complex<double>>::value)
150+
if (phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
154151
{
155-
if (phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
156-
{
157-
continue;
158-
}
152+
continue;
159153
}
160154

161155
auto col_indexes = pv.get_indexes_col(ibt2);

source/module_hamilt_lcao/module_deepks/deepks_pdm.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
102102

103103
{
104104
ModuleBase::TITLE("DeePKS_domain", "cal_pdm");
105+
ModuleBase::timer::tick("DeePKS_domain", "cal_pdm");
105106

106107
// if pdm has been initialized, skip the calculation
107108
if (init_pdm)
@@ -133,8 +134,6 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
133134
}
134135
}
135136

136-
ModuleBase::timer::tick("DeePKS_domain", "cal_pdm");
137-
138137
const double Rcut_Alpha = orb.Alpha[0].getRcut();
139138
for (int T0 = 0; T0 < ucell.ntype; T0++)
140139
{
@@ -204,12 +203,9 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
204203
}
205204

206205
ModuleBase::Vector3<int> dR1(GridD.getBox(ad1).x, GridD.getBox(ad1).y, GridD.getBox(ad1).z);
207-
if constexpr (std::is_same<TK, std::complex<double>>::value)
206+
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr)
208207
{
209-
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr)
210-
{
211-
continue;
212-
}
208+
continue;
213209
}
214210

215211
auto row_indexes = pv.get_indexes_row(ibt1);
@@ -242,12 +238,9 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
242238
const int nw2_tot = atom2->nw * PARAM.globalv.npol;
243239

244240
ModuleBase::Vector3<int> dR2(GridD.getBox(ad2).x, GridD.getBox(ad2).y, GridD.getBox(ad2).z);
245-
if constexpr (std::is_same<TK, std::complex<double>>::value)
241+
if (phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
246242
{
247-
if (phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
248-
{
249-
continue;
250-
}
243+
continue;
251244
}
252245

253246
const double Rcut_AO2 = orb.Phi[T2].getRcut();

0 commit comments

Comments
 (0)