Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a)
endif()

target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR})

if (USE_SW)
add_compile_definitions(__SW)
set(SW ON)
Expand All @@ -295,6 +298,7 @@ if (USE_SW)
target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswblas.a)
endif()


find_package(Threads REQUIRED)
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)

Expand Down
6 changes: 4 additions & 2 deletions source/source_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ extern "C"
}
namespace mtfunc
{
std::complex<double>* alp=nullptr;
std::complex<double>* bet=nullptr;
void dspInitHandle(int id)
{
mt_blas_init(id);
Expand Down Expand Up @@ -271,9 +273,9 @@ void zgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
// std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
// std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),
Expand Down
2 changes: 2 additions & 0 deletions source/source_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ void* malloc_ht(size_t bytes, int cluster_id);
void free_ht(void* ptr);

// mtblas functions
extern std::complex<double>* alp;
extern std::complex<double>* bet;

void sgemm_mt_(const char* transa,
const char* transb,
Expand Down
2 changes: 1 addition & 1 deletion source/source_basis/module_pw/module_fft/fft_dsp.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "fft_dsp.h"

#include "source_base/global_variable.h"

#include "source_base/tool_quit.h"
#include <iostream>
#include <string.h>
#include <vector>
Expand Down
2 changes: 2 additions & 0 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
#ifdef __DSP
std::cout << " ** Initializing DSP Hardware..." << std::endl;
mtfunc::dspInitHandle(GlobalV::MY_RANK);
mtfunc::alp=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), GlobalV::MY_RANK);
mtfunc::bet=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), GlobalV::MY_RANK);
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions source/source_io/cal_ldos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void stm_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,

for (int ib = 0; ib < nbands; ib++)
{
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);

const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV;
double weight = en > 0 ? pelec->klist->wk[ik] - pelec->wg(ik, ib) : pelec->wg(ik, ib);
Expand Down Expand Up @@ -210,7 +210,7 @@ void ldos_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,

for (int ib = 0; ib < nbands; ib++)
{
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);
const double weight = pelec->klist->wk[ik] / ucell.omega;

for (int ir = 0; ir < pelec->basis->nrxx; ir++)
Expand Down
2 changes: 1 addition & 1 deletion source/source_io/cal_mlkedf_descriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ void Cal_MLKEDF_Descriptors::getF_KS(
wfcr[ig] = psi->operator()(ibnd, ig) * std::complex<double>(0.0, fact);
}

pw_psi->recip2real(wfcr, wfcr, ik);
pw_psi->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfcr, wfcr, ik);

for (int ir = 0; ir < this->nx; ++ir)
{
Expand Down
6 changes: 3 additions & 3 deletions source/source_io/get_wf_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
// Calculate real-space wave functions
psi_g.fix_k(is);
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), is);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), is);

// Extract real and imaginary parts
std::vector<double> wfc_real(pw_wfc->nrxx);
Expand Down Expand Up @@ -399,7 +399,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,

// Calculate real-space wave functions
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), ik);

// Extract real and imaginary parts
std::vector<double> wfc_real(pw_wfc->nrxx);
Expand Down Expand Up @@ -551,7 +551,7 @@ void Get_wf_lcao::set_pw_wfc(const ModulePW::PW_Basis_K* pw_wfc,
}

// call FFT
pw_wfc->real2recip(Porter.data(), &wfc_g(ib, 0), ik);
pw_wfc->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(Porter.data(), &wfc_g(ib, 0), ik);
}

#ifdef __MPI
Expand Down
6 changes: 3 additions & 3 deletions source/source_io/read_wf2rho_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ void ModuleIO::read_wf2rho_pw(
{
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
const std::complex<double>* wfc_ib2 = wfc_tmp.c + ib * ng_npol + ng_npol / 2;
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip2real(wfc_ib2, rho_tmp2.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib2, rho_tmp2.data(), ik);
const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;

if (w1 != 0.0)
Expand All @@ -152,7 +152,7 @@ void ModuleIO::read_wf2rho_pw(
for (int ib = 0; ib < nbands; ++ib)
{
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik);

const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;

Expand Down
19 changes: 9 additions & 10 deletions source/source_io/to_wannier90_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ void toWannier90_PW::out_unk(
{
int ib = cal_band_index[ib_w];

wfcpw->recip2real(&psi_pw(ik, ib, 0), porter, ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(ik, ib, 0), porter, ik);

if (GlobalV::RANK_IN_POOL == 0)
{
Expand Down Expand Up @@ -383,7 +383,7 @@ void toWannier90_PW::unkdotkb(
}
}

wfcpw->recip2real(phase, phase, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(phase, phase, cal_ik);

if (PARAM.inp.nspin == 4)
{
Expand All @@ -396,17 +396,17 @@ void toWannier90_PW::unkdotkb(
// (2) fft and get value
// int npw_ik = wfcpw->npwk[cal_ik];
int npwx = wfcpw->npwk_max;
wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir_up, cal_ik);
// wfcpw->recip2real(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik);
wfcpw->recip2real(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir_up, cal_ik);
// wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik);
for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
psir_up[ir] *= phase[ir];
psir_dn[ir] *= phase[ir];
}

wfcpw->real2recip(psir_up, psir_up, cal_ikb);
wfcpw->real2recip(psir_dn, psir_dn, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir_up, psir_up, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir_dn, psir_dn, cal_ikb);

for (int n = 0; n < num_bands; n++)
{
Expand Down Expand Up @@ -447,13 +447,12 @@ void toWannier90_PW::unkdotkb(
ModuleBase::GlobalFunc::ZEROS(psir, wfcpw->nmaxgr);

// (2) fft and get value
wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir, cal_ik);
for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
psir[ir] *= phase[ir];
}

wfcpw->real2recip(psir, psir, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir, psir, cal_ikb);

for (int n = 0; n < num_bands; n++)
{
Expand Down
10 changes: 5 additions & 5 deletions source/source_io/unk_overlap_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
}

// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_r, psi_r, ik_R);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_r, psi_r, ik_R);

for (int ig = 0; ig < evc->get_ngk(ik_R); ig++)
{
Expand Down Expand Up @@ -197,8 +197,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho

// (2) fft and get value
rhopw->recip2real(phase, phase);
wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);

for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
Expand All @@ -207,8 +207,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
}

// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_up, psi_up, ik_L);
wfcpw->real2recip(psi_down, psi_down, ik_L);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_up, psi_up, ik_L);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_down, psi_down, ik_L);

for (int i = 0; i < PARAM.globalv.npol; i++)
{
Expand Down
4 changes: 2 additions & 2 deletions source/source_pw/module_pwdft/stress_func_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ void Stress_PW<FPTYPE, Device>::stress_exx(ModuleBase::matrix& sigma,
// psi_nk in real space
d_psi_in->fix_kb(ik, nband);
T* psi_nk = d_psi_in->get_pointer();
wfcpw->recip2real(psi_nk, psi_nk_real, ik);
wfcpw->recip_to_real<std::complex<FPTYPE>,Device>(psi_nk, psi_nk_real, ik);

for (int iq = 0; iq < nqs; iq++)
{
Expand All @@ -269,7 +269,7 @@ void Stress_PW<FPTYPE, Device>::stress_exx(ModuleBase::matrix& sigma,
// psi_mq in real space
d_psi_in->fix_kb(iq, mband);
T* psi_mq = d_psi_in->get_pointer();
wfcpw->recip2real(psi_mq, psi_mq_real, iq);
wfcpw->recip_to_real<std::complex<FPTYPE>,Device>(psi_mq, psi_mq_real, iq);

// overlap density in real space
setmem_complex_op()(density_real, 0.0, rhopw->nrxx);
Expand Down
Loading