diff --git a/CMakeLists.txt b/CMakeLists.txt index 277f1924ec..704685480e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -282,6 +282,9 @@ if (USE_DSP) target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a) target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a) endif() + +target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR}) + if (USE_SW) add_compile_definitions(__SW) set(SW ON) @@ -295,6 +298,7 @@ if (USE_SW) target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswblas.a) endif() + find_package(Threads REQUIRED) target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads) diff --git a/source/source_base/kernels/dsp/dsp_connector.cpp b/source/source_base/kernels/dsp/dsp_connector.cpp index a3c5f6d897..a5dc9a1991 100644 --- a/source/source_base/kernels/dsp/dsp_connector.cpp +++ b/source/source_base/kernels/dsp/dsp_connector.cpp @@ -12,16 +12,33 @@ extern "C" } namespace mtfunc { +std::complex* gemm_alpha_double=nullptr; +std::complex* gemm_beta_double=nullptr; +std::complex* gemm_alpha_float=nullptr; +std::complex* gemm_beta_float=nullptr; + void dspInitHandle(int id) { mt_blas_init(id); std::cout << " ** DSP inited on cluster " << id << " **" << std::endl; + mtfunc::gemm_alpha_double=(std::complex*)mtfunc::malloc_ht(sizeof(std::complex), id); + mtfunc::gemm_beta_double=(std::complex*)mtfunc::malloc_ht(sizeof(std::complex), id); + mtfunc::gemm_alpha_float=(std::complex*)mtfunc::malloc_ht(sizeof(std::complex), id); + mtfunc::gemm_beta_float=(std::complex*)mtfunc::malloc_ht(sizeof(std::complex), id); } // Use this at the beginning of the program to start a dsp cluster void dspDestoryHandle(int id) { hthread_dev_close(id); std::cout << " ** DSP closed on cluster " << id << " **" << std::endl; + mtfunc::free_ht(mtfunc::gemm_alpha_double); + mtfunc::free_ht(mtfunc::gemm_beta_double); + mtfunc::free_ht(mtfunc::gemm_alpha_float); + mtfunc::free_ht(mtfunc::gemm_beta_float); + mtfunc::gemm_alpha_double = nullptr; + mtfunc::gemm_beta_double = nullptr; + mtfunc::gemm_alpha_float = nullptr; + mtfunc::gemm_beta_float = nullptr; } // Close dsp cluster at the end MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) @@ -45,9 +62,7 @@ MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) void* malloc_ht(size_t bytes, int cluster_id) { - // std::cout << "MALLOC " << cluster_id; void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW); - // std::cout << ptr << " SUCCEED" << std::endl;; return ptr; } @@ -55,9 +70,7 @@ void* malloc_ht(size_t bytes, int cluster_id) void free_ht(void* ptr) { - // std::cout << "FREE " << ptr; hthread_free(ptr); - // std::cout << " FREE SUCCEED" << std::endl; } // Used to replace original free @@ -271,22 +284,20 @@ void zgemm_mth_(const char* transa, const int* ldc, int cluster_id) { - std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); - *alp = *alpha; - std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); - *bet = *beta; + *gemm_alpha_double = *alpha; + *gemm_beta_double = *beta; mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, *n, *k, - alp, + gemm_alpha_double, a, *lda, b, *ldb, - bet, + gemm_beta_double, c, *ldc, cluster_id); @@ -308,10 +319,8 @@ void cgemm_mth_(const char* transa, const int* ldc, int cluster_id) { - std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); - *alp = *alpha; - std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); - *bet = *beta; + gemm_alpha_float = alpha; + gemm_beta_float = beta; mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor, convertBLASTranspose(transa), @@ -319,17 +328,14 @@ void cgemm_mth_(const char* transa, *m, *n, *k, - (const void*)alp, + (const void*)gemm_alpha_float, (const void*)a, *lda, (const void*)b, *ldb, - (const void*)bet, + (const void*)gemm_beta_float, (void*)c, *ldc, cluster_id); - - free_ht(alp); - free_ht(bet); } // cgemm that needn't malloc_ht or free_ht } // namespace mtfunc \ No newline at end of file diff --git a/source/source_base/kernels/dsp/dsp_connector.h b/source/source_base/kernels/dsp/dsp_connector.h index 34ccbaec4b..30ca282193 100644 --- a/source/source_base/kernels/dsp/dsp_connector.h +++ b/source/source_base/kernels/dsp/dsp_connector.h @@ -15,7 +15,10 @@ void* malloc_ht(size_t bytes, int cluster_id); void free_ht(void* ptr); // mtblas functions - +extern std::complex* gemm_alpha_double; +extern std::complex* gemm_beta_double; +extern std::complex* gemm_alpha_float; +extern std::complex* gemm_beta_float; void sgemm_mt_(const char* transa, const char* transb, const int* m, diff --git a/source/source_basis/module_pw/module_fft/fft_dsp.cpp b/source/source_basis/module_pw/module_fft/fft_dsp.cpp index e26292cf5b..0842066eb0 100644 --- a/source/source_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/source_basis/module_pw/module_fft/fft_dsp.cpp @@ -1,7 +1,7 @@ #include "fft_dsp.h" #include "source_base/global_variable.h" - +#include "source_base/tool_quit.h" #include #include #include diff --git a/source/source_io/cal_ldos.cpp b/source/source_io/cal_ldos.cpp index ec2f00bfc7..ecea39202f 100644 --- a/source/source_io/cal_ldos.cpp +++ b/source/source_io/cal_ldos.cpp @@ -140,7 +140,7 @@ void stm_mode_pw(const elecstate::ElecStatePW>* pelec, for (int ib = 0; ib < nbands; ib++) { - pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik); + pelec->basis->recip_to_real,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik); const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV; double weight = en > 0 ? pelec->klist->wk[ik] - pelec->wg(ik, ib) : pelec->wg(ik, ib); @@ -210,7 +210,7 @@ void ldos_mode_pw(const elecstate::ElecStatePW>* pelec, for (int ib = 0; ib < nbands; ib++) { - pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik); + pelec->basis->recip_to_real,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik); const double weight = pelec->klist->wk[ik] / ucell.omega; for (int ir = 0; ir < pelec->basis->nrxx; ir++) diff --git a/source/source_io/cal_mlkedf_descriptors.cpp b/source/source_io/cal_mlkedf_descriptors.cpp index 58e3f8f259..24b7517963 100644 --- a/source/source_io/cal_mlkedf_descriptors.cpp +++ b/source/source_io/cal_mlkedf_descriptors.cpp @@ -472,7 +472,7 @@ void Cal_MLKEDF_Descriptors::getF_KS( wfcr[ig] = psi->operator()(ibnd, ig) * std::complex(0.0, fact); } - pw_psi->recip2real(wfcr, wfcr, ik); + pw_psi->recip_to_real,base_device::DEVICE_CPU>(wfcr, wfcr, ik); for (int ir = 0; ir < this->nx; ++ir) { diff --git a/source/source_io/get_wf_lcao.cpp b/source/source_io/get_wf_lcao.cpp index 3d6c58a300..e784191b82 100644 --- a/source/source_io/get_wf_lcao.cpp +++ b/source/source_io/get_wf_lcao.cpp @@ -179,7 +179,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell, // Calculate real-space wave functions psi_g.fix_k(is); std::vector> wfc_r(pw_wfc->nrxx); - pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), is); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), is); // Extract real and imaginary parts std::vector wfc_real(pw_wfc->nrxx); @@ -399,7 +399,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell, // Calculate real-space wave functions std::vector> wfc_r(pw_wfc->nrxx); - pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), ik); // Extract real and imaginary parts std::vector wfc_real(pw_wfc->nrxx); @@ -551,7 +551,7 @@ void Get_wf_lcao::set_pw_wfc(const ModulePW::PW_Basis_K* pw_wfc, } // call FFT - pw_wfc->real2recip(Porter.data(), &wfc_g(ib, 0), ik); + pw_wfc->real_to_recip,base_device::DEVICE_CPU>(Porter.data(), &wfc_g(ib, 0), ik); } #ifdef __MPI diff --git a/source/source_io/read_wf2rho_pw.cpp b/source/source_io/read_wf2rho_pw.cpp index 1be65a268c..66f41b9448 100644 --- a/source/source_io/read_wf2rho_pw.cpp +++ b/source/source_io/read_wf2rho_pw.cpp @@ -129,8 +129,8 @@ void ModuleIO::read_wf2rho_pw( { const std::complex* wfc_ib = wfc_tmp.c + ib * ng_npol; const std::complex* wfc_ib2 = wfc_tmp.c + ib * ng_npol + ng_npol / 2; - pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik); - pw_wfc->recip2real(wfc_ib2, rho_tmp2.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(wfc_ib2, rho_tmp2.data(), ik); const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega; if (w1 != 0.0) @@ -152,7 +152,7 @@ void ModuleIO::read_wf2rho_pw( for (int ib = 0; ib < nbands; ++ib) { const std::complex* wfc_ib = wfc_tmp.c + ib * ng_npol; - pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik); const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega; diff --git a/source/source_io/to_wannier90_pw.cpp b/source/source_io/to_wannier90_pw.cpp index 9c33cf4976..72d628fe8d 100644 --- a/source/source_io/to_wannier90_pw.cpp +++ b/source/source_io/to_wannier90_pw.cpp @@ -253,7 +253,7 @@ void toWannier90_PW::out_unk( { int ib = cal_band_index[ib_w]; - wfcpw->recip2real(&psi_pw(ik, ib, 0), porter, ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(ik, ib, 0), porter, ik); if (GlobalV::RANK_IN_POOL == 0) { @@ -383,7 +383,7 @@ void toWannier90_PW::unkdotkb( } } - wfcpw->recip2real(phase, phase, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(phase, phase, cal_ik); if (PARAM.inp.nspin == 4) { @@ -396,17 +396,17 @@ void toWannier90_PW::unkdotkb( // (2) fft and get value // int npw_ik = wfcpw->npwk[cal_ik]; int npwx = wfcpw->npwk_max; - wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir_up, cal_ik); - // wfcpw->recip2real(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik); - wfcpw->recip2real(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir_up, cal_ik); + // wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik); for (int ir = 0; ir < wfcpw->nrxx; ir++) { psir_up[ir] *= phase[ir]; psir_dn[ir] *= phase[ir]; } - wfcpw->real2recip(psir_up, psir_up, cal_ikb); - wfcpw->real2recip(psir_dn, psir_dn, cal_ikb); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psir_up, psir_up, cal_ikb); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psir_dn, psir_dn, cal_ikb); for (int n = 0; n < num_bands; n++) { @@ -447,13 +447,12 @@ void toWannier90_PW::unkdotkb( ModuleBase::GlobalFunc::ZEROS(psir, wfcpw->nmaxgr); // (2) fft and get value - wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir, cal_ik); for (int ir = 0; ir < wfcpw->nrxx; ir++) { psir[ir] *= phase[ir]; } - - wfcpw->real2recip(psir, psir, cal_ikb); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psir, psir, cal_ikb); for (int n = 0; n < num_bands; n++) { diff --git a/source/source_io/unk_overlap_pw.cpp b/source/source_io/unk_overlap_pw.cpp index 0c87b1f6fb..1b4af5e7b1 100644 --- a/source/source_io/unk_overlap_pw.cpp +++ b/source/source_io/unk_overlap_pw.cpp @@ -93,7 +93,7 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, } // (3) calculate the overlap in ik_L and ik_R - wfcpw->real2recip(psi_r, psi_r, ik_R); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psi_r, psi_r, ik_R); for (int ig = 0; ig < evc->get_ngk(ik_R); ig++) { @@ -197,8 +197,8 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho // (2) fft and get value rhopw->recip2real(phase, phase); - wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L); - wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, 0), psi_up, ik_L); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L); for (int ir = 0; ir < wfcpw->nrxx; ir++) { @@ -207,8 +207,8 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho } // (3) calculate the overlap in ik_L and ik_R - wfcpw->real2recip(psi_up, psi_up, ik_L); - wfcpw->real2recip(psi_down, psi_down, ik_L); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psi_up, psi_up, ik_L); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psi_down, psi_down, ik_L); for (int i = 0; i < PARAM.globalv.npol; i++) { diff --git a/source/source_pw/module_pwdft/stress_func_exx.cpp b/source/source_pw/module_pwdft/stress_func_exx.cpp index 1885bb16ee..b01b24848a 100644 --- a/source/source_pw/module_pwdft/stress_func_exx.cpp +++ b/source/source_pw/module_pwdft/stress_func_exx.cpp @@ -260,7 +260,7 @@ void Stress_PW::stress_exx(ModuleBase::matrix& sigma, // psi_nk in real space d_psi_in->fix_kb(ik, nband); T* psi_nk = d_psi_in->get_pointer(); - wfcpw->recip2real(psi_nk, psi_nk_real, ik); + wfcpw->recip_to_real,Device>(psi_nk, psi_nk_real, ik); for (int iq = 0; iq < nqs; iq++) { @@ -269,7 +269,7 @@ void Stress_PW::stress_exx(ModuleBase::matrix& sigma, // psi_mq in real space d_psi_in->fix_kb(iq, mband); T* psi_mq = d_psi_in->get_pointer(); - wfcpw->recip2real(psi_mq, psi_mq_real, iq); + wfcpw->recip_to_real,Device>(psi_mq, psi_mq_real, iq); // overlap density in real space setmem_complex_op()(density_real, 0.0, rhopw->nrxx); diff --git a/source/source_pw/module_stodft/sto_iter.cpp b/source/source_pw/module_stodft/sto_iter.cpp index aa9990a415..22e7a7329c 100644 --- a/source/source_pw/module_stodft/sto_iter.cpp +++ b/source/source_pw/module_stodft/sto_iter.cpp @@ -567,7 +567,11 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, const int npw = this->pkv->ngk[ik]; const double kweight = this->pkv->wk[ik]; T* hshchi = nullptr; + #ifdef __DSP + base_device::memory::resize_memory_op_mt()(hshchi, nchip_ik * npwx); + #else resmem_complex_op()(hshchi, nchip_ik * npwx); + #endif T* tmpin = stowf.shchi->get_pointer(); T* tmpout = hshchi; p_hamilt_sto->hPsi(tmpin, tmpout, nchip_ik); @@ -577,7 +581,11 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, tmpin += npwx; tmpout += npwx; } + #ifdef __DSP + base_device::memory::delete_memory_op_mt()(hshchi); + #else delmem_complex_op()(hshchi); + #endif } } #ifdef __MPI