Skip to content

Commit 80cc432

Browse files
authored
[Refactor] Replaced all exposed mathematical library interface (deepmodeling#6828)
* Replaced all exposed mathematical library interface * Add some other replacements * Fix uspp bug * Fix gemm
1 parent 2156598 commit 80cc432

File tree

20 files changed

+366
-365
lines changed

20 files changed

+366
-365
lines changed

source/source_estate/module_dm/cal_dm_psi.cpp

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -164,24 +164,24 @@ void psiMulPsiMpi(const psi::Psi<double>& psi1,
164164
const int nlocal = desc_dm[2];
165165
const int nbands = desc_psi[3];
166166

167-
pdgemm_(&N_char,
168-
&T_char,
169-
&nlocal,
170-
&nlocal,
171-
&nbands,
172-
&one_float,
167+
ScalapackConnector::gemm(N_char,
168+
T_char,
169+
nlocal,
170+
nlocal,
171+
nbands,
172+
one_float,
173173
psi1.get_pointer(),
174-
&one_int,
175-
&one_int,
174+
one_int,
175+
one_int,
176176
desc_psi,
177177
psi2.get_pointer(),
178-
&one_int,
179-
&one_int,
178+
one_int,
179+
one_int,
180180
desc_psi,
181-
&zero_float,
181+
zero_float,
182182
dm_out,
183-
&one_int,
184-
&one_int,
183+
one_int,
184+
one_int,
185185
desc_dm);
186186
ModuleBase::timer::tick("psiMulPsiMpi", "pdgemm");
187187
}
@@ -198,24 +198,24 @@ void psiMulPsiMpi(const psi::Psi<std::complex<double>>& psi1,
198198
const char N_char = 'N', T_char = 'T';
199199
const int nlocal = desc_dm[2];
200200
const int nbands = desc_psi[3];
201-
pzgemm_(&N_char,
202-
&T_char,
203-
&nlocal,
204-
&nlocal,
205-
&nbands,
206-
&one_complex,
201+
ScalapackConnector::gemm(N_char,
202+
T_char,
203+
nlocal,
204+
nlocal,
205+
nbands,
206+
one_complex,
207207
psi1.get_pointer(),
208-
&one_int,
209-
&one_int,
208+
one_int,
209+
one_int,
210210
desc_psi,
211211
psi2.get_pointer(),
212-
&one_int,
213-
&one_int,
212+
one_int,
213+
one_int,
214214
desc_psi,
215-
&zero_complex,
215+
zero_complex,
216216
dm_out,
217-
&one_int,
218-
&one_int,
217+
one_int,
218+
one_int,
219219
desc_dm);
220220
ModuleBase::timer::tick("psiMulPsiMpi", "pdgemm");
221221
}
@@ -229,19 +229,19 @@ void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, doubl
229229
const char N_char = 'N', T_char = 'T';
230230
const int nlocal = psi1.get_nbasis();
231231
const int nbands = psi1.get_nbands();
232-
dgemm_(&N_char,
233-
&T_char,
234-
&nlocal,
235-
&nlocal,
236-
&nbands,
237-
&one_float,
232+
BlasConnector::gemm_cm(N_char,
233+
T_char,
234+
nlocal,
235+
nlocal,
236+
nbands,
237+
one_float,
238238
psi1.get_pointer(),
239-
&nlocal,
239+
nlocal,
240240
psi2.get_pointer(),
241-
&nlocal,
242-
&zero_float,
241+
nlocal,
242+
zero_float,
243243
dm_out,
244-
&nlocal);
244+
nlocal);
245245
}
246246

247247
void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
@@ -254,19 +254,19 @@ void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
254254
const int nbands = psi1.get_nbands();
255255
const std::complex<double> one_complex = {1.0, 0.0};
256256
const std::complex<double> zero_complex = {0.0, 0.0};
257-
zgemm_(&N_char,
258-
&T_char,
259-
&nlocal,
260-
&nlocal,
261-
&nbands,
262-
&one_complex,
257+
BlasConnector::gemm_cm(N_char,
258+
T_char,
259+
nlocal,
260+
nlocal,
261+
nbands,
262+
one_complex,
263263
psi1.get_pointer(),
264-
&nlocal,
264+
nlocal,
265265
psi2.get_pointer(),
266-
&nlocal,
267-
&zero_complex,
266+
nlocal,
267+
zero_complex,
268268
dm_out,
269-
&nlocal);
269+
nlocal);
270270
}
271271

272272
} // namespace elecstate

source/source_io/output_mulliken.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -547,24 +547,24 @@ void Output_Mulliken<std::complex<double>>::cal_orbMulP()
547547
const char N_char = 'N';
548548
const int one_int = 1;
549549
const std::complex<double> one_float = {1.0, 0.0}, zero_float = {0.0, 0.0};
550-
pzgemm_(&N_char,
551-
&T_char,
552-
&nw,
553-
&nw,
554-
&nw,
555-
&one_float,
550+
ScalapackConnector::gemm(N_char,
551+
T_char,
552+
nw,
553+
nw,
554+
nw,
555+
one_float,
556556
p_DMk,
557-
&one_int,
558-
&one_int,
557+
one_int,
558+
one_int,
559559
this->ParaV_->desc,
560560
p_Sk,
561-
&one_int,
562-
&one_int,
561+
one_int,
562+
one_int,
563563
this->ParaV_->desc,
564-
&zero_float,
564+
zero_float,
565565
mud.c,
566-
&one_int,
567-
&one_int,
566+
one_int,
567+
one_int,
568568
this->ParaV_->desc);
569569
this->collect_MW(MecMulP, mud, nw, this->isk_[ik]);
570570
#endif
@@ -597,24 +597,24 @@ void Output_Mulliken<double>::cal_orbMulP()
597597
const char N_char = 'N';
598598
const int one_int = 1;
599599
const double one_float = 1.0, zero_float = 0.0;
600-
pdgemm_(&N_char,
601-
&T_char,
602-
&nw,
603-
&nw,
604-
&nw,
605-
&one_float,
600+
ScalapackConnector::gemm(N_char,
601+
T_char,
602+
nw,
603+
nw,
604+
nw,
605+
one_float,
606606
p_DMk,
607-
&one_int,
608-
&one_int,
607+
one_int,
608+
one_int,
609609
this->ParaV_->desc,
610610
p_Sk,
611-
&one_int,
612-
&one_int,
611+
one_int,
612+
one_int,
613613
this->ParaV_->desc,
614-
&zero_float,
614+
zero_float,
615615
mud.c,
616-
&one_int,
617-
&one_int,
616+
one_int,
617+
one_int,
618618
this->ParaV_->desc);
619619
if (this->nspin_ == 1 || this->nspin_ == 2)
620620
{

source/source_io/to_wannier90_lcao.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -423,44 +423,44 @@ void toWannier90_LCAO::unkdotkb(const UnitCell& ucell,
423423
ModuleBase::GlobalFunc::ZEROS(out_matrix, nloc);
424424

425425
#ifdef __MPI
426-
pzgemm_(&transa,
427-
&transb,
428-
&Bands,
429-
&nlocal,
430-
&nlocal,
431-
&alpha,
426+
ScalapackConnector::gemm(transa,
427+
transb,
428+
Bands,
429+
nlocal,
430+
nlocal,
431+
alpha,
432432
&psi_in(ik, 0, 0),
433-
&one,
434-
&one,
433+
one,
434+
one,
435435
this->ParaV->desc,
436436
midmatrix,
437-
&one,
438-
&one,
437+
one,
438+
one,
439439
this->ParaV->desc,
440-
&beta,
440+
beta,
441441
C_matrix,
442-
&one,
443-
&one,
442+
one,
443+
one,
444444
this->ParaV->desc);
445445

446-
pzgemm_(&transb,
447-
&transb,
448-
&Bands,
449-
&Bands,
450-
&nlocal,
451-
&alpha,
446+
ScalapackConnector::gemm(transb,
447+
transb,
448+
Bands,
449+
Bands,
450+
nlocal,
451+
alpha,
452452
C_matrix,
453-
&one,
454-
&one,
453+
one,
454+
one,
455455
this->ParaV->desc,
456456
&psi_in(ikb, 0, 0),
457-
&one,
458-
&one,
457+
one,
458+
one,
459459
this->ParaV->desc,
460-
&beta,
460+
beta,
461461
out_matrix,
462-
&one,
463-
&one,
462+
one,
463+
one,
464464
this->ParaV->desc);
465465
#endif
466466

source/source_io/unk_overlap_lcao.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -563,44 +563,44 @@ std::complex<double> unkOverlap_lcao::det_berryphase(const UnitCell& ucell,
563563
std::complex<double> alpha = {1.0, 0.0}, beta = {0.0, 0.0};
564564
int one = 1;
565565
#ifdef __MPI
566-
pzgemm_(&transa,
567-
&transb,
568-
&occBands,
569-
&nlocal,
570-
&nlocal,
571-
&alpha,
566+
ScalapackConnector::gemm(transa,
567+
transb,
568+
occBands,
569+
nlocal,
570+
nlocal,
571+
alpha,
572572
&psi_in[0](ik_L, 0, 0),
573-
&one,
574-
&one,
573+
one,
574+
one,
575575
para_orb.desc,
576576
midmatrix,
577-
&one,
578-
&one,
577+
one,
578+
one,
579579
para_orb.desc,
580-
&beta,
580+
beta,
581581
C_matrix,
582-
&one,
583-
&one,
582+
one,
583+
one,
584584
para_orb.desc);
585585

586-
pzgemm_(&transb,
587-
&transb,
588-
&occBands,
589-
&occBands,
590-
&nlocal,
591-
&alpha,
586+
ScalapackConnector::gemm(transb,
587+
transb,
588+
occBands,
589+
occBands,
590+
nlocal,
591+
alpha,
592592
C_matrix,
593-
&one,
594-
&one,
593+
one,
594+
one,
595595
para_orb.desc,
596596
&psi_in[0](ik_R, 0, 0),
597-
&one,
598-
&one,
597+
one,
598+
one,
599599
para_orb.desc,
600-
&beta,
600+
beta,
601601
out_matrix,
602-
&one,
603-
&one,
602+
one,
603+
one,
604604
para_orb.desc);
605605

606606
assert(para_orb.nrow>0);

source/source_lcao/module_deepks/deepks_orbpre.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,19 +213,19 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
213213
gemm_alpha = 2.0;
214214
}
215215

216-
dgemm_(&transa,
217-
&transb,
218-
&row_size_nks,
219-
&trace_alpha_size,
220-
&col_size,
221-
&gemm_alpha,
222-
dm_array.data(),
223-
&col_size,
216+
BlasConnector::gemm(transb,
217+
transa,
218+
trace_alpha_size,
219+
row_size_nks,
220+
col_size,
221+
gemm_alpha,
224222
s_2t.data(),
225-
&col_size,
226-
&gemm_beta,
223+
col_size,
224+
dm_array.data(),
225+
col_size,
226+
gemm_beta,
227227
g_1dmt.data(),
228-
&row_size_nks);
228+
row_size_nks);
229229
} // ad2
230230

231231
for (int ik = 0; ik < nks; ik++)

0 commit comments

Comments
 (0)