Skip to content

Commit efdc9f8

Browse files
committed
Merge branch 'develop' of https://github.com/mohanchen/abacus-mc into develop
2 parents 26d8cf5 + 10a9c41 commit efdc9f8

File tree

17 files changed

+172
-114
lines changed

17 files changed

+172
-114
lines changed

source/source_base/module_device/device.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ std::string get_device_flag(const std::string& device,
6565
/**
6666
* @brief Get the rank of current node
6767
* Note that a GPU can only be bound to a CPU in the same node
68-
*
69-
* @return int
68+
*
69+
* @return int
7070
*/
7171
int get_node_rank();
7272
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
@@ -91,6 +91,14 @@ void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::str
9191
return;
9292
}
9393

94+
#if defined(__CUDA) || defined(__ROCM)
95+
template <>
96+
void print_device_info<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU *ctx, std::ofstream &ofs_device);
97+
98+
template <>
99+
void record_device_memory<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* dev, std::ofstream& ofs_device, std::string str, size_t size);
100+
#endif
101+
94102
} // end of namespace information
95103
} // end of namespace base_device
96104

source/source_base/module_device/output_device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void print_device_info<base_device::DEVICE_GPU>(
190190
ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n";
191191
}
192192
int dev = 0, driverVersion = 0, runtimeVersion = 0;
193-
cudaErrcheck(cudaSetDevice(dev));
193+
cudaErrcheck(cudaGetDevice(&dev));
194194
cudaDeviceProp deviceProp;
195195
cudaErrcheck(cudaGetDeviceProperties(&deviceProp, dev));
196196
ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;
@@ -429,7 +429,7 @@ void print_device_info<base_device::DEVICE_GPU>(
429429
ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n";
430430
}
431431
int dev = 0, driverVersion = 0, runtimeVersion = 0;
432-
hipErrcheck(hipSetDevice(dev));
432+
hipErrcheck(hipGetDevice(&dev));
433433
hipDeviceProp_t deviceProp;
434434
hipErrcheck(hipGetDeviceProperties(&deviceProp, dev));
435435
ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;

source/source_base/timer.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
#include "source_base/formatter.h"
1616

1717
#if defined(__CUDA) && defined(__USE_NVTX)
18-
#include <nvToolsExt.h>
18+
#if CUDA_VERSION < 12090
19+
#include "nvToolsExt.h"
20+
#else
21+
#include "nvtx3/nvToolsExt.h"
22+
#endif
1923
#include "source_io/module_parameter/parameter.h"
2024
#endif
2125

source/source_hsolver/test/CMakeLists.txt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,19 @@ install(FILES diago_pexsi_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DI
153153
install(FILES parallel_k2d_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
154154

155155

156-
157-
AddTest(
158-
TARGET MODULE_HSOLVER_diago_hs_parallel
159-
LIBS parameter ${math_libs} ELPA::ELPA base device MPI::MPI_CXX genelpa psi
160-
SOURCES test_diago_hs_para.cpp ../diag_hs_para.cpp ../diago_pxxxgvx.cpp ../diago_elpa.cpp ../diago_scalapack.cpp
161-
)
156+
if (USE_ELPA)
157+
AddTest(
158+
TARGET MODULE_HSOLVER_diago_hs_parallel
159+
LIBS parameter ${math_libs} ELPA::ELPA base device MPI::MPI_CXX genelpa psi
160+
SOURCES test_diago_hs_para.cpp ../diag_hs_para.cpp ../diago_pxxxgvx.cpp ../diago_elpa.cpp ../diago_scalapack.cpp
161+
)
162+
else()
163+
AddTest(
164+
TARGET MODULE_HSOLVER_diago_hs_parallel
165+
LIBS parameter ${math_libs} base device MPI::MPI_CXX psi
166+
SOURCES test_diago_hs_para.cpp ../diag_hs_para.cpp ../diago_pxxxgvx.cpp ../diago_scalapack.cpp
167+
)
168+
endif()
162169

163170
AddTest(
164171
TARGET MODULE_HSOLVER_linear_trans

source/source_hsolver/test/test_diago_hs_para.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ void test_performance(int lda, int nb, int nbands, MPI_Comm comm,int case_numb,
160160
MPI_Comm_size(comm, &nproc);
161161

162162
std::vector<T> h_mat, s_mat, wfc, h_psi, s_psi;
163+
#ifdef __ELPA
163164
std::vector<typename GetTypeReal<T>::type> ekb_elpa(lda);
165+
#endif
164166
std::vector<typename GetTypeReal<T>::type> ekb_scalap(lda);
165167
std::vector<typename GetTypeReal<T>::type> ekb_lapack(lda);
166168

@@ -176,32 +178,36 @@ void test_performance(int lda, int nb, int nbands, MPI_Comm comm,int case_numb,
176178
}
177179

178180
// store all the times in a vector
181+
#ifdef __ELPA
179182
std::vector<double> time_elpa(case_numb, 0);
183+
#endif
180184
std::vector<double> time_scalap(case_numb, 0);
181185
std::vector<double> time_lapack(case_numb, 0);
182186

183187
if (my_rank == 0) { std::cout << "Random matrix ";
184188
}
185-
for (int randomi = 0; randomi < case_numb; ++randomi)
189+
for (int randomi = 0; randomi < case_numb; ++randomi)
186190
{
187-
191+
188192
if (my_rank == 0) {
189193
std::cout << randomi << " ";
190194
generate_random_hs(lda, randomi, h_mat, s_mat);
191195
}
192-
196+
auto start = std::chrono::high_resolution_clock::now();
197+
auto end = std::chrono::high_resolution_clock::now();
198+
#ifdef __ELPA
193199
// ELPA
194200
MPI_Barrier(comm);
195-
auto start = std::chrono::high_resolution_clock::now();
201+
start = std::chrono::high_resolution_clock::now();
196202
for (int j=0;j<loop_numb;j++)
197203
{
198204
hsolver::diago_hs_para<T>(h_mat.data(), s_mat.data(), lda, nbands,ekb_elpa.data(), wfc.data(), comm, 1, nb);
199205
MPI_Barrier(comm);
200206
}
201207
MPI_Barrier(comm);
202-
auto end = std::chrono::high_resolution_clock::now();
208+
end = std::chrono::high_resolution_clock::now();
203209
time_elpa[randomi] = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
204-
210+
#endif
205211

206212
// scalapack
207213
start = std::chrono::high_resolution_clock::now();
@@ -215,8 +221,8 @@ void test_performance(int lda, int nb, int nbands, MPI_Comm comm,int case_numb,
215221
time_scalap[randomi] = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
216222

217223
// LAPACK
218-
if (my_rank == 0)
219-
{
224+
if (my_rank == 0)
225+
{
220226
std::vector<T> h_tmp, s_tmp;
221227
start = std::chrono::high_resolution_clock::now();
222228
base_device::DEVICE_CPU* ctx = {};
@@ -239,26 +245,34 @@ void test_performance(int lda, int nb, int nbands, MPI_Comm comm,int case_numb,
239245

240246
//COMPARE EKB
241247
for (int i = 0; i < nbands; ++i) {
242-
typename GetTypeReal<T>::type diff_elpa_lapack = std::abs(ekb_elpa[i] - ekb_lapack[i]);
243248
typename GetTypeReal<T>::type diff_scalap_lapack = std::abs(ekb_scalap[i] - ekb_lapack[i]);
249+
#ifdef __ELPA
250+
typename GetTypeReal<T>::type diff_elpa_lapack = std::abs(ekb_elpa[i] - ekb_lapack[i]);
244251
if (diff_elpa_lapack > 1e-6 || diff_scalap_lapack > 1e-6)
252+
#else
253+
if (diff_scalap_lapack > 1e-6)
254+
#endif
245255
{
256+
#ifdef __ELPA
246257
std::cout << "eigenvalue " << i << " by ELPA: " << ekb_elpa[i] << std::endl;
258+
#endif
247259
std::cout << "eigenvalue " << i << " by Scalapack: " << ekb_scalap[i] << std::endl;
248260
std::cout << "eigenvalue " << i << " by Lapack: " << ekb_lapack[i] << std::endl;
249261
}
250262
}
251263
}
252-
MPI_Barrier(comm);
264+
MPI_Barrier(comm);
253265

254266
}
255267

256268
if (my_rank == 0)
257269
{
270+
#ifdef __ELPA
258271
std::cout << "\nELPA Time : ";
259272
for (int i=0; i < case_numb;i++)
260273
{std::cout << time_elpa[i] << " ";}
261274
std::cout << std::endl;
275+
#endif
262276

263277
std::cout << "scalapack Time: ";
264278
for (int i=0; i < case_numb;i++)
@@ -271,21 +285,29 @@ void test_performance(int lda, int nb, int nbands, MPI_Comm comm,int case_numb,
271285
std::cout << std::endl;
272286

273287
// print out the average time and speedup
288+
#ifdef __ELPA
274289
double avg_time_elpa = 0;
290+
#endif
275291
double avg_time_scalap = 0;
276292
double avg_time_lapack = 0;
277293
for (int i=0; i < case_numb;i++)
278294
{
295+
#ifdef __ELPA
279296
avg_time_elpa += time_elpa[i];
297+
#endif
280298
avg_time_scalap += time_scalap[i];
281299
avg_time_lapack += time_lapack[i];
282300
}
283301

302+
#ifdef __ELPA
284303
avg_time_elpa /= case_numb;
304+
#endif
285305
avg_time_scalap /= case_numb;
286306
avg_time_lapack /= case_numb;
287307
std::cout << "Average Lapack Time : " << avg_time_lapack << " ms" << std::endl;
308+
#ifdef __ELPA
288309
std::cout << "Average ELPA Time : " << avg_time_elpa << " ms, Speedup: " << avg_time_lapack / avg_time_elpa << std::endl;
310+
#endif
289311
std::cout << "Average Scalapack Time: " << avg_time_scalap << " ms, Speedup: " << avg_time_lapack / avg_time_scalap << std::endl;
290312
}
291313
}

source/source_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,24 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
618618
// 7. atom.npy, box.npy, overlap.npy
619619
//================================================================================
620620

621+
if( ((PARAM.inp.deepks_out_labels == 2) && is_after_scf )
622+
|| ( PARAM.inp.deepks_out_freq_elec ) )// need overlap when deepks_out_freq_elec
623+
{
624+
if (PARAM.inp.deepks_v_delta > 0)
625+
{
626+
// prepare for overlap.npy, very much like h_tot except for p_ham->getSk()
627+
std::vector<TH> s_tot(nks);
628+
DeePKS_domain::get_h_tot<TK, TH, TR>(*ParaV, p_ham, s_tot, nlocal, nks, 'S');
629+
const std::string file_stot = get_filename("overlap", PARAM.inp.deepks_out_labels, iter);
630+
LCAO_deepks_io::save_npy_h<TK, TH>(s_tot,
631+
file_stot,
632+
nlocal,
633+
nks,
634+
rank,
635+
1.0); // don't need unit_scale for overlap
636+
}
637+
}
638+
621639
if ( is_after_scf ) // don't need to output in multiple electronic steps
622640
{
623641
if (PARAM.inp.deepks_out_labels == 2)
@@ -632,20 +650,6 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
632650
DeePKS_domain::prepare_box(ucell, box_out);
633651
const std::string file_box = PARAM.globalv.global_out_dir + "deepks_box.npy";
634652
LCAO_deepks_io::save_tensor2npy<double>(file_box, box_out, rank);
635-
636-
if (PARAM.inp.deepks_v_delta > 0)
637-
{
638-
// prepare for overlap.npy, very much like h_tot except for p_ham->getSk()
639-
std::vector<TH> s_tot(nks);
640-
DeePKS_domain::get_h_tot<TK, TH, TR>(*ParaV, p_ham, s_tot, nlocal, nks, 'S');
641-
const std::string file_stot = PARAM.globalv.global_out_dir + "deepks_overlap.npy";
642-
LCAO_deepks_io::save_npy_h<TK, TH>(s_tot,
643-
file_stot,
644-
nlocal,
645-
nks,
646-
rank,
647-
1.0); // don't need unit_scale for overlap
648-
}
649653
}
650654

651655
//================================================================================

source/source_lcao/module_gint/gint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ class Gint {
265265
std::vector<hamilt::HContainer<double>*> DMRGint;
266266

267267
//! tmp tools used in transfer_DM2DtoGrid
268-
hamilt::HContainer<double>* DMRGint_full = nullptr;
268+
hamilt::HContainer<double>* dm2d_tmp = nullptr;
269269

270270
std::vector<hamilt::HContainer<double>> pvdpRx_reduced;
271271
std::vector<hamilt::HContainer<double>> pvdpRy_reduced;

source/source_lcao/module_gint/gint_old.cpp

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Gint::~Gint() {
3333
delete this->hRGint_tmp[is];
3434
}
3535
#ifdef __MPI
36-
delete this->DMRGint_full;
36+
delete this->dm2d_tmp;
3737
#endif
3838
}
3939

@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
172172
}
173173
#ifdef __MPI
174-
if (this->DMRGint_full != nullptr) {
175-
delete this->DMRGint_full;
174+
if (this->dm2d_tmp != nullptr) {
175+
delete this->dm2d_tmp;
176176
}
177-
this->DMRGint_full = new hamilt::HContainer<double>(ucell_in.nat);
178177
#endif
179178
}
180179

@@ -210,12 +209,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
210209
ModuleBase::Memory::record("Gint::DMRGint",
211210
this->DMRGint[0]->get_memory_size()
212211
* this->DMRGint.size()*nspin);
213-
#ifdef __MPI
214-
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
215-
this->DMRGint_full->allocate(nullptr, true);
216-
ModuleBase::Memory::record("Gint::DMRGint_full",
217-
this->DMRGint_full->get_memory_size());
218-
#endif
219212
}
220213
}
221214

@@ -231,9 +224,7 @@ void Gint::reset_DMRGint(const int& nspin)
231224
{
232225
for (auto& d : this->DMRGint) { d->allocate(nullptr, false); }
233226
#ifdef __MPI
234-
delete this->DMRGint_full;
235-
this->DMRGint_full = new hamilt::HContainer<double>(*this->hRGint);
236-
this->DMRGint_full->allocate(nullptr, false);
227+
delete this->dm2d_tmp;
237228
#endif
238229
}
239230
}
@@ -262,37 +253,46 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262253
} else // NSPIN=4 case
263254
{
264255
#ifdef __MPI
265-
hamilt::transferParallels2Serials(*DM2D[0], this->DMRGint_full);
266-
#else
267-
this->DMRGint_full = DM2D[0];
268-
#endif
269-
std::vector<double*> tmp_pointer(4, nullptr);
270-
for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
271-
auto& ap = this->DMRGint_full->get_atom_pair(iap);
272-
int iat1 = ap.get_atom_i();
273-
int iat2 = ap.get_atom_j();
274-
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
275-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
276-
for (int is = 0; is < 4; is++) {
277-
tmp_pointer[is] = this->DMRGint[is]
278-
->find_matrix(iat1, iat2, r_index)
279-
->get_pointer();
280-
}
281-
double* data_full = ap.get_pointer(ir);
282-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
283-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284-
*(tmp_pointer[0])++ = data_full[icol];
285-
*(tmp_pointer[1])++ = data_full[icol + 1];
286-
}
287-
data_full += ap.get_col_size();
288-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
289-
*(tmp_pointer[2])++ = data_full[icol];
290-
*(tmp_pointer[3])++ = data_full[icol + 1];
256+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
257+
const int row_set[4] = {0, 0, 1, 1};
258+
const int col_set[4] = {0, 1, 0, 1};
259+
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
260+
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
261+
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
262+
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
263+
std::vector<int> iat2iwt(ucell->nat);
264+
for (int iat = 0; iat < ucell->nat; iat++) {
265+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
266+
}
267+
Parallel_Orbitals *pv = new Parallel_Orbitals();
268+
pv->set(mg, ng, nb, blacs_ctxt);
269+
pv->set_atomic_trace(iat2iwt.data(), ucell->nat, mg);
270+
auto ijr_info = DM2D[0]->get_ijr_info();
271+
this-> dm2d_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
272+
ModuleBase::Memory::record("Gint::dm2d_tmp", this->dm2d_tmp->get_memory_size());
273+
for (int is = 0; is < 4; is++){
274+
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
275+
auto& ap = DM2D[0]->get_atom_pair(iap);
276+
int iat1 = ap.get_atom_i();
277+
int iat2 = ap.get_atom_j();
278+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
279+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
280+
double* matrix_out = this -> dm2d_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
281+
double* matrix_in = ap.get_pointer(ir);
282+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
283+
for (int icol = 0; icol < ap.get_col_size()/2; icol++){
284+
int index_i = irow* ap.get_col_size()/2 + icol;
285+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
286+
matrix_out[index_i] = matrix_in[index_j];
287+
}
291288
}
292-
data_full += ap.get_col_size();
293289
}
294290
}
291+
hamilt::transferParallels2Serials( *(this->dm2d_tmp), this->DMRGint[is]);
295292
}
293+
#else
294+
//this->DMRGint_full = DM2D[0];
295+
#endif
296296
}
297297
ModuleBase::timer::tick("Gint", "transfer_DMR");
298298
}

0 commit comments

Comments
 (0)