Skip to content

Fix bug in dsp compute #6433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a)
endif()

target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR})

if (USE_SW)
add_compile_definitions(__SW)
set(SW ON)
Expand All @@ -295,6 +298,7 @@ if (USE_SW)
target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswblas.a)
endif()


find_package(Threads REQUIRED)
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)

Expand Down
44 changes: 25 additions & 19 deletions source/source_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,33 @@ extern "C"
}
namespace mtfunc
{
std::complex<double>* gemm_alpha_double=nullptr;
std::complex<double>* gemm_beta_double=nullptr;
std::complex<float>* gemm_alpha_float=nullptr;
std::complex<float>* gemm_beta_float=nullptr;

void dspInitHandle(int id)
{
mt_blas_init(id);
std::cout << " ** DSP inited on cluster " << id << " **" << std::endl;
mtfunc::gemm_alpha_double=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), id);
mtfunc::gemm_beta_double=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), id);
mtfunc::gemm_alpha_float=(std::complex<float>*)mtfunc::malloc_ht(sizeof(std::complex<float>), id);
mtfunc::gemm_beta_float=(std::complex<float>*)mtfunc::malloc_ht(sizeof(std::complex<float>), id);
} // Use this at the beginning of the program to start a dsp cluster

void dspDestoryHandle(int id)
{
hthread_dev_close(id);
std::cout << " ** DSP closed on cluster " << id << " **" << std::endl;
mtfunc::free_ht(mtfunc::gemm_alpha_double);
mtfunc::free_ht(mtfunc::gemm_beta_double);
mtfunc::free_ht(mtfunc::gemm_alpha_float);
mtfunc::free_ht(mtfunc::gemm_beta_float);
mtfunc::gemm_alpha_double = nullptr;
mtfunc::gemm_beta_double = nullptr;
mtfunc::gemm_alpha_float = nullptr;
mtfunc::gemm_beta_float = nullptr;
} // Close dsp cluster at the end

MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
Expand All @@ -45,19 +62,15 @@ MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)

void* malloc_ht(size_t bytes, int cluster_id)
{
// std::cout << "MALLOC " << cluster_id;
void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW);
// std::cout << ptr << " SUCCEED" << std::endl;;
return ptr;
}

// Used to replace original malloc

void free_ht(void* ptr)
{
// std::cout << "FREE " << ptr;
hthread_free(ptr);
// std::cout << " FREE SUCCEED" << std::endl;
}

// Used to replace original free
Expand Down Expand Up @@ -271,22 +284,20 @@ void zgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;
*gemm_alpha_double = *alpha;
*gemm_beta_double = *beta;
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
*n,
*k,
alp,
gemm_alpha_double,
a,
*lda,
b,
*ldb,
bet,
gemm_beta_double,
c,
*ldc,
cluster_id);
Expand All @@ -308,28 +319,23 @@ void cgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
std::complex<float>* alp = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*alp = *alpha;
std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*bet = *beta;
gemm_alpha_float = alpha;
gemm_beta_float = beta;

mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
*n,
*k,
(const void*)alp,
(const void*)gemm_alpha_float,
(const void*)a,
*lda,
(const void*)b,
*ldb,
(const void*)bet,
(const void*)gemm_beta_float,
(void*)c,
*ldc,
cluster_id);

free_ht(alp);
free_ht(bet);
} // cgemm that needn't malloc_ht or free_ht
} // namespace mtfunc
5 changes: 4 additions & 1 deletion source/source_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ void* malloc_ht(size_t bytes, int cluster_id);
void free_ht(void* ptr);

// mtblas functions

extern std::complex<double>* gemm_alpha_double;
extern std::complex<double>* gemm_beta_double;
extern std::complex<float>* gemm_alpha_float;
extern std::complex<float>* gemm_beta_float;
void sgemm_mt_(const char* transa,
const char* transb,
const int* m,
Expand Down
2 changes: 1 addition & 1 deletion source/source_basis/module_pw/module_fft/fft_dsp.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "fft_dsp.h"

#include "source_base/global_variable.h"

#include "source_base/tool_quit.h"
#include <iostream>
#include <string.h>
#include <vector>
Expand Down
4 changes: 2 additions & 2 deletions source/source_io/cal_ldos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void stm_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,

for (int ib = 0; ib < nbands; ib++)
{
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);

const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV;
double weight = en > 0 ? pelec->klist->wk[ik] - pelec->wg(ik, ib) : pelec->wg(ik, ib);
Expand Down Expand Up @@ -210,7 +210,7 @@ void ldos_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,

for (int ib = 0; ib < nbands; ib++)
{
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);
const double weight = pelec->klist->wk[ik] / ucell.omega;

for (int ir = 0; ir < pelec->basis->nrxx; ir++)
Expand Down
2 changes: 1 addition & 1 deletion source/source_io/cal_mlkedf_descriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ void Cal_MLKEDF_Descriptors::getF_KS(
wfcr[ig] = psi->operator()(ibnd, ig) * std::complex<double>(0.0, fact);
}

pw_psi->recip2real(wfcr, wfcr, ik);
pw_psi->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfcr, wfcr, ik);

for (int ir = 0; ir < this->nx; ++ir)
{
Expand Down
6 changes: 3 additions & 3 deletions source/source_io/get_wf_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
// Calculate real-space wave functions
psi_g.fix_k(is);
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), is);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), is);

// Extract real and imaginary parts
std::vector<double> wfc_real(pw_wfc->nrxx);
Expand Down Expand Up @@ -399,7 +399,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,

// Calculate real-space wave functions
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), ik);

// Extract real and imaginary parts
std::vector<double> wfc_real(pw_wfc->nrxx);
Expand Down Expand Up @@ -551,7 +551,7 @@ void Get_wf_lcao::set_pw_wfc(const ModulePW::PW_Basis_K* pw_wfc,
}

// call FFT
pw_wfc->real2recip(Porter.data(), &wfc_g(ib, 0), ik);
pw_wfc->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(Porter.data(), &wfc_g(ib, 0), ik);
}

#ifdef __MPI
Expand Down
6 changes: 3 additions & 3 deletions source/source_io/read_wf2rho_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ void ModuleIO::read_wf2rho_pw(
{
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
const std::complex<double>* wfc_ib2 = wfc_tmp.c + ib * ng_npol + ng_npol / 2;
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip2real(wfc_ib2, rho_tmp2.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib2, rho_tmp2.data(), ik);
const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;

if (w1 != 0.0)
Expand All @@ -152,7 +152,7 @@ void ModuleIO::read_wf2rho_pw(
for (int ib = 0; ib < nbands; ++ib)
{
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik);

const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;

Expand Down
19 changes: 9 additions & 10 deletions source/source_io/to_wannier90_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ void toWannier90_PW::out_unk(
{
int ib = cal_band_index[ib_w];

wfcpw->recip2real(&psi_pw(ik, ib, 0), porter, ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(ik, ib, 0), porter, ik);

if (GlobalV::RANK_IN_POOL == 0)
{
Expand Down Expand Up @@ -383,7 +383,7 @@ void toWannier90_PW::unkdotkb(
}
}

wfcpw->recip2real(phase, phase, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(phase, phase, cal_ik);

if (PARAM.inp.nspin == 4)
{
Expand All @@ -396,17 +396,17 @@ void toWannier90_PW::unkdotkb(
// (2) fft and get value
// int npw_ik = wfcpw->npwk[cal_ik];
int npwx = wfcpw->npwk_max;
wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir_up, cal_ik);
// wfcpw->recip2real(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik);
wfcpw->recip2real(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir_up, cal_ik);
// wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik);
for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
psir_up[ir] *= phase[ir];
psir_dn[ir] *= phase[ir];
}

wfcpw->real2recip(psir_up, psir_up, cal_ikb);
wfcpw->real2recip(psir_dn, psir_dn, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir_up, psir_up, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir_dn, psir_dn, cal_ikb);

for (int n = 0; n < num_bands; n++)
{
Expand Down Expand Up @@ -447,13 +447,12 @@ void toWannier90_PW::unkdotkb(
ModuleBase::GlobalFunc::ZEROS(psir, wfcpw->nmaxgr);

// (2) fft and get value
wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir, cal_ik);
for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
psir[ir] *= phase[ir];
}

wfcpw->real2recip(psir, psir, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir, psir, cal_ikb);

for (int n = 0; n < num_bands; n++)
{
Expand Down
10 changes: 5 additions & 5 deletions source/source_io/unk_overlap_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
}

// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_r, psi_r, ik_R);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_r, psi_r, ik_R);

for (int ig = 0; ig < evc->get_ngk(ik_R); ig++)
{
Expand Down Expand Up @@ -197,8 +197,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho

// (2) fft and get value
rhopw->recip2real(phase, phase);
wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);

for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
Expand All @@ -207,8 +207,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
}

// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_up, psi_up, ik_L);
wfcpw->real2recip(psi_down, psi_down, ik_L);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_up, psi_up, ik_L);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_down, psi_down, ik_L);

for (int i = 0; i < PARAM.globalv.npol; i++)
{
Expand Down
4 changes: 2 additions & 2 deletions source/source_pw/module_pwdft/stress_func_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ void Stress_PW<FPTYPE, Device>::stress_exx(ModuleBase::matrix& sigma,
// psi_nk in real space
d_psi_in->fix_kb(ik, nband);
T* psi_nk = d_psi_in->get_pointer();
wfcpw->recip2real(psi_nk, psi_nk_real, ik);
wfcpw->recip_to_real<std::complex<FPTYPE>,Device>(psi_nk, psi_nk_real, ik);

for (int iq = 0; iq < nqs; iq++)
{
Expand All @@ -269,7 +269,7 @@ void Stress_PW<FPTYPE, Device>::stress_exx(ModuleBase::matrix& sigma,
// psi_mq in real space
d_psi_in->fix_kb(iq, mband);
T* psi_mq = d_psi_in->get_pointer();
wfcpw->recip2real(psi_mq, psi_mq_real, iq);
wfcpw->recip_to_real<std::complex<FPTYPE>,Device>(psi_mq, psi_mq_real, iq);

// overlap density in real space
setmem_complex_op()(density_real, 0.0, rhopw->nrxx);
Expand Down
8 changes: 8 additions & 0 deletions source/source_pw/module_stodft/sto_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,11 @@ void Stochastic_Iter<T, Device>::sum_stoeband(Stochastic_WF<T, Device>& stowf,
const int npw = this->pkv->ngk[ik];
const double kweight = this->pkv->wk[ik];
T* hshchi = nullptr;
#ifdef __DSP
base_device::memory::resize_memory_op_mt<T, Device>()(hshchi, nchip_ik * npwx);
#else
resmem_complex_op()(hshchi, nchip_ik * npwx);
#endif
T* tmpin = stowf.shchi->get_pointer();
T* tmpout = hshchi;
p_hamilt_sto->hPsi(tmpin, tmpout, nchip_ik);
Expand All @@ -577,7 +581,11 @@ void Stochastic_Iter<T, Device>::sum_stoeband(Stochastic_WF<T, Device>& stowf,
tmpin += npwx;
tmpout += npwx;
}
#ifdef __DSP
base_device::memory::delete_memory_op_mt<T, Device>()(hshchi);
#else
delmem_complex_op()(hshchi);
#endif
}
}
#ifdef __MPI
Expand Down
Loading