Skip to content

Commit e277c1c

Browse files
authored
Fix: dsp memory op (#7056)
* Fix dsp setmem op * Clean up the code
1 parent db8a9dd commit e277c1c

4 files changed

Lines changed: 58 additions & 17 deletions

File tree

source/source_base/module_device/memory_op.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,19 @@ struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
471471
}
472472
};
473473

474+
// CPU specialization of set_memory_op_mt: fills `size` elements of `arr`
// byte-wise with the value `var`, splitting the range between the worker
// threads handed out by ModuleBase::OMP_PARALLEL.
//
// \param arr  : destination array; written in place
// \param var  : fill value, forwarded to memset (which stores it as a single
//               byte into every byte of the destination)
// \param size : array length; each thread's share `len` is scaled by
//               sizeof(FPTYPE) before being passed to memset
template <typename FPTYPE>
475+
struct set_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
476+
{
477+
void operator()(FPTYPE* arr, const int var, const size_t size)
478+
{
479+
// The lambda runs once per worker; presumably thread_id is in [0, num_thread)
// -- confirm against ModuleBase::OMP_PARALLEL.
ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
480+
// NOTE(review): beg/len are int while `size` is size_t. Offsets past INT_MAX
// elements may not be representable here -- confirm the output types accepted
// by BLOCK_TASK_DIST_1D before widening these to size_t.
int beg = 0, len = 0;
481+
// Work is handed out in blocks of 4096 / sizeof(FPTYPE) elements, i.e. 4096
// bytes per block (presumably to keep each thread on whole memory pages --
// TODO confirm). beg/len receive this thread's start offset and length.
ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len);
482+
// Byte-wise fill of this thread's slice only.
memset(arr + beg, var, sizeof(FPTYPE) * len);
483+
});
484+
}
485+
};
486+
474487
template <typename FPTYPE>
475488
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
476489
{
@@ -487,6 +500,12 @@ template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
487500
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
488501
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
489502

503+
// Explicit instantiations of the CPU set_memory_op_mt specialization for the
// element types int, float, double, std::complex<float> and
// std::complex<double>, so the definition in this translation unit is emitted
// for each of them.
template struct set_memory_op_mt<int, base_device::DEVICE_CPU>;
504+
template struct set_memory_op_mt<float, base_device::DEVICE_CPU>;
505+
template struct set_memory_op_mt<double, base_device::DEVICE_CPU>;
506+
template struct set_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
507+
template struct set_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
508+
490509
template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
491510
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
492511
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;

source/source_base/module_device/memory_op.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ struct resize_memory_op_mt
234234
void operator()(FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
235235
};
236236

237+
// Primary template of the "mt" set-memory functor. Only operator() is declared
// here; the definition is supplied per device by specializations (a
// base_device::DEVICE_CPU one lives in memory_op.cpp).
template <typename FPTYPE, typename Device>
238+
struct set_memory_op_mt
239+
{
240+
/// @brief memset-style fill for DSP memory allocated by the mt allocator.
241+
///
242+
/// Input Parameters
243+
/// \param var : the constant byte value written to every byte of the array
244+
/// \param size : array size, counted in FPTYPE elements (the CPU specialization scales it by sizeof(FPTYPE))
245+
///
246+
/// Output Parameters
247+
/// \param arr : output array, every byte of which is set to the input value
248+
void operator()(FPTYPE* arr, const int var, const size_t size);
249+
};
250+
237251
template <typename FPTYPE, typename Device>
238252
struct delete_memory_op_mt
239253
{

source/source_pw/module_pwdft/op_pw_nl.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,15 @@ class Nonlocal<OperatorPW<T, Device>> : public OperatorPW<T, Device>
8888
using gemv_op = ModuleBase::gemv_op<T, Device>;
8989
using gemm_op = ModuleBase::gemm_op<T, Device>;
9090
using nonlocal_op = nonlocal_pw_op<Real, Device>;
91-
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
92-
#ifdef __DSP
91+
#ifdef __DSP
92+
using setmem_complex_op = base_device::memory::set_memory_op_mt<T, Device>;
9393
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
9494
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
95-
#else
95+
#else
96+
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
9697
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
9798
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
98-
#endif
99+
#endif
99100
using syncmem_complex_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
100101

101102
T one{1, 0};
@@ -104,4 +105,4 @@ class Nonlocal<OperatorPW<T, Device>> : public OperatorPW<T, Device>
104105

105106
} // namespace hamilt
106107

107-
#endif
108+
#endif

source/source_pw/module_pwdft/vnl_pw.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ void pseudopot_cell_vnl::release_memory()
6464
delmem_ch_op()(this->c_deeq_nc);
6565
delmem_ch_op()(this->c_vkb);
6666
delmem_ch_op()(this->c_qq_so);
67+
#ifdef __DSP
68+
if (this->z_vkb != nullptr)
69+
{
70+
base_device::memory::delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>()(this->z_vkb);
71+
this->z_vkb = nullptr;
72+
}
73+
#endif
6774
// There's no need to delete double precision pointers while in a CPU environment.
6875
}
6976
memory_released = true;
@@ -273,13 +280,13 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell,
273280
resmem_sh_op()(s_tab, this->tab.getSize());
274281
resmem_ch_op()(c_vkb, nkb * npwx);
275282
}
276-
#ifdef __DSP
283+
#ifdef __DSP
277284
base_device::memory::resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>()
278-
(this->z_vkb, this->vkb.size, "Nonlocal<PW>::ps");
279-
memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16);
280-
#else
285+
(this->z_vkb, this->vkb.size, "VNL::z_vkb");
286+
// memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16);
287+
#else
281288
this->z_vkb = this->vkb.c;
282-
#endif
289+
#endif
283290
this->d_tab = this->tab.ptr;
284291
// There's no need to delete double precision pointers while in a CPU environment.
285292
}
@@ -293,12 +300,12 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell,
293300
// with structure factor, for all atoms, in reciprocal space
294301
//----------------------------------------------------------
295302
template <typename FPTYPE, typename Device>
296-
void pseudopot_cell_vnl::getvnl(Device* ctx,
303+
void pseudopot_cell_vnl::getvnl(Device* ctx,
297304
const UnitCell& ucell,
298-
const int& ik,
305+
const int& ik,
299306
std::complex<FPTYPE>* vkb_in) const
300307
{
301-
if (PARAM.inp.test_pp)
308+
if (PARAM.inp.test_pp)
302309
{
303310
ModuleBase::TITLE("pseudopot_cell_vnl", "getvnl");
304311
}
@@ -732,10 +739,10 @@ void pseudopot_cell_vnl::init_vnl(UnitCell& cell, const ModulePW::PW_Basis* rho_
732739
for (int iq = 0; iq < PARAM.globalv.nqx; iq++)
733740
{
734741
const double q = iq * PARAM.globalv.dq;
735-
ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl);
742+
ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl);
736743
for (int ir = 0; ir < kkbeta; ir++)
737-
{
738-
aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir];
744+
{
745+
aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir];
739746
}
740747
double vqint=0.0;
741748
ModuleBase::Integral::Simpson_Integral(kkbeta, aux, cell.atoms[it].ncpp.rab.data(), vqint);
@@ -1723,7 +1730,7 @@ template void pseudopot_cell_vnl::getvnl<float, base_device::DEVICE_CPU>(base_de
17231730
int const&,
17241731
std::complex<float>*) const;
17251732
template void pseudopot_cell_vnl::getvnl<double, base_device::DEVICE_CPU>(base_device::DEVICE_CPU*,
1726-
const UnitCell&,
1733+
const UnitCell&,
17271734
int const&,
17281735
std::complex<double>*) const;
17291736
#if defined(__CUDA) || defined(__ROCM)

0 commit comments

Comments
 (0)