Skip to content

Commit d485c50

Browse files
committed
Remove all ctx parameters in set_memory_op
1 parent 851d8bf commit d485c50

28 files changed

+78
-84
lines changed

source/module_base/math_chebyshev.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ void Chebyshev<REAL, Device>::calfinalvec_real(
437437
funA(arrayn_1, arrayn, m);
438438

439439
// 0- & 1-st order
440-
setmem_complex_op()(this->ctx, waveout, 0, ndmxt);
440+
setmem_complex_op()(waveout, 0, ndmxt);
441441
std::complex<REAL> coef0 = std::complex<REAL>(coefr_cpu[0], 0);
442442
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coef0, arrayn_1, 1, waveout, 1);
443443
std::complex<REAL> coef1 = std::complex<REAL>(coefr_cpu[1], 0);
@@ -505,7 +505,7 @@ void Chebyshev<REAL, Device>::calfinalvec_complex(
505505
funA(arrayn_1, arrayn, m);
506506

507507
// 0- & 1-st order
508-
setmem_complex_op()(this->ctx, waveout, 0, ndmxt);
508+
setmem_complex_op()(waveout, 0, ndmxt);
509509
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coefc_cpu[0], arrayn_1, 1, waveout, 1);
510510
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coefc_cpu[1], arrayn, 1, waveout, 1);
511511
// for (int i = 0; i < ndmxt; ++i)

source/module_base/module_device/cuda/memory_op.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ __global__ void cast_memory(std::complex<FPTYPE_out>* out, const FPTYPE_in* in,
5353

5454
template <typename FPTYPE>
5555
void resize_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE*& arr,
56-
const size_t size,
56+
const size_t size,
5757
const char* record_in)
5858
{
5959
if (arr != nullptr)
@@ -78,8 +78,7 @@ void resize_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE*& arr,
7878
}
7979

8080
template <typename FPTYPE>
81-
void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
82-
FPTYPE* arr,
81+
void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE* arr,
8382
const int var,
8483
const size_t size)
8584
{

source/module_base/module_device/memory_op.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
4545
template <typename FPTYPE>
4646
struct set_memory_op<FPTYPE, base_device::DEVICE_CPU>
4747
{
48-
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr, const int var, const size_t size)
48+
void operator()(FPTYPE* arr, const int var, const size_t size)
4949
{
5050
ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
5151
int beg = 0, len = 0;
@@ -166,7 +166,7 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_GPU>
166166
template <typename FPTYPE>
167167
struct set_memory_op<FPTYPE, base_device::DEVICE_GPU>
168168
{
169-
void operator()(const base_device::DEVICE_GPU* dev, FPTYPE* arr, const int var, const size_t size)
169+
void operator()(FPTYPE* arr, const int var, const size_t size)
170170
{
171171
}
172172
};

source/module_base/module_device/memory_op.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,12 @@ struct set_memory_op
3232
/// @brief memset for multi-device
3333
///
3434
/// Input Parameters
35-
/// \param dev : the type of computing device
3635
/// \param var : the specified constant value
3736
/// \param size : array size
3837
///
3938
/// Output Parameters
4039
/// \param arr : output array initialized by the input value
41-
void operator()(const Device* dev, FPTYPE* arr, const int var, const size_t size);
40+
void operator()(FPTYPE* arr, const int var, const size_t size);
4241
};
4342

4443
template <typename FPTYPE, typename Device_out, typename Device_in>
@@ -120,7 +119,7 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_GPU>
120119
template <typename FPTYPE>
121120
struct set_memory_op<FPTYPE, base_device::DEVICE_GPU>
122121
{
123-
void operator()(const base_device::DEVICE_GPU* dev, FPTYPE* arr, const int var, const size_t size);
122+
void operator()(FPTYPE* arr, const int var, const size_t size);
124123
};
125124

126125
template <typename FPTYPE>

source/module_base/module_device/rocm/memory_op.hip.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ __global__ void cast_memory(std::complex<FPTYPE_out>* out, const std::complex<FP
3939
}
4040

4141
template <typename FPTYPE>
42-
void resize_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
43-
FPTYPE*& arr,
42+
void resize_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE*& arr,
4443
const size_t size,
4544
const char* record_in)
4645
{

source/module_base/module_device/test/memory_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class TestModulePsiMemory : public ::testing::Test
9191
TEST_F(TestModulePsiMemory, set_memory_op_double_cpu)
9292
{
9393
std::vector<double> v_xx = xx;
94-
set_memory_double_cpu_op()(cpu_ctx, v_xx.data(), 0, xx.size());
94+
set_memory_double_cpu_op()(v_xx.data(), 0, xx.size());
9595
for (int ii = 0; ii < xx.size(); ii++)
9696
{
9797
EXPECT_EQ(v_xx[ii], 0.0);
@@ -101,7 +101,7 @@ TEST_F(TestModulePsiMemory, set_memory_op_double_cpu)
101101
TEST_F(TestModulePsiMemory, set_memory_op_complex_double_cpu)
102102
{
103103
std::vector<std::complex<double>> vz_xx = z_xx;
104-
set_memory_complex_double_cpu_op()(cpu_ctx, vz_xx.data(), 0, z_xx.size());
104+
set_memory_complex_double_cpu_op()(vz_xx.data(), 0, z_xx.size());
105105
for (int ii = 0; ii < z_xx.size(); ii++)
106106
{
107107
EXPECT_EQ(vz_xx[ii], std::complex<double>(0.0, 0.0));
@@ -175,7 +175,7 @@ TEST_F(TestModulePsiMemory, set_memory_op_double_gpu)
175175
{
176176
thrust::device_ptr<double> d_xx = thrust::device_malloc<double>(xx.size());
177177
thrust::copy(xx.begin(), xx.end(), d_xx);
178-
set_memory_double_gpu_op()(gpu_ctx, thrust::raw_pointer_cast(d_xx), 0, xx.size());
178+
set_memory_double_gpu_op()(thrust::raw_pointer_cast(d_xx), 0, xx.size());
179179
thrust::host_vector<double> h_xx(xx.size());
180180
thrust::copy(d_xx, d_xx + xx.size(), h_xx.begin());
181181
for (int ii = 0; ii < xx.size(); ii++)
@@ -188,7 +188,7 @@ TEST_F(TestModulePsiMemory, set_memory_op_complex_double_gpu)
188188
{
189189
thrust::device_ptr<std::complex<double>> dz_xx = thrust::device_malloc<std::complex<double>>(z_xx.size());
190190
thrust::copy(z_xx.begin(), z_xx.end(), dz_xx);
191-
set_memory_complex_double_gpu_op()(gpu_ctx, thrust::raw_pointer_cast(dz_xx), 0, z_xx.size());
191+
set_memory_complex_double_gpu_op()(thrust::raw_pointer_cast(dz_xx), 0, z_xx.size());
192192
thrust::host_vector<std::complex<double>> h_xx(z_xx.size());
193193
thrust::copy(dz_xx, dz_xx + z_xx.size(), h_xx.begin());
194194
for (int ii = 0; ii < z_xx.size(); ii++)

source/module_basis/module_pw/pw_transform_k.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
413413
assert(this->poolnproc == 1);
414414
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<float>(), this->nxyz);
415415
base_device::memory::set_memory_op<std::complex<float>, base_device::DEVICE_GPU>()(
416-
ctx,
417416
this->fft_bundle.get_auxr_3d_data<float>(),
418417
0,
419418
this->nxyz);
@@ -450,7 +449,6 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
450449
assert(this->poolnproc == 1);
451450
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz);
452451
base_device::memory::set_memory_op<std::complex<double>, base_device::DEVICE_GPU>()(
453-
ctx,
454452
this->fft_bundle.get_auxr_3d_data<double>(),
455453
0,
456454
this->nxyz);

source/module_elecstate/elecstate_pw.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
118118
{
119119
// denghui replaced at 20221110
120120
// ModuleBase::GlobalFunc::ZEROS(this->rho[is], this->charge->nrxx);
121-
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
121+
setmem_var_op()(this->rho[is], 0, this->charge->nrxx);
122122
if (get_xc_func_type() == 3)
123123
{
124124
// ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
125-
setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx);
125+
setmem_var_op()(this->kin_r[is], 0, this->charge->nrxx);
126126
}
127127
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
128128
{
129-
setmem_complex_op()(this->ctx, this->rhog[is], 0, this->charge->rhopw->npw);
129+
setmem_complex_op()(this->rhog[is], 0, this->charge->rhopw->npw);
130130
}
131131
}
132132

@@ -244,7 +244,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
244244
{
245245
for (int j = 0; j < 3; j++)
246246
{
247-
setmem_complex_op()(this->ctx, this->wfcr, 0, this->charge->nrxx);
247+
setmem_complex_op()(this->wfcr, 0, this->charge->nrxx);
248248

249249
meta_op()(this->ctx,
250250
ik,
@@ -280,7 +280,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
280280
resmem_complex_op()(becp, nbands * nkb, "ElecState<PW>::becp");
281281
const int nh_tot = this->ppcell->nhm * (this->ppcell->nhm + 1) / 2;
282282
resmem_var_op()(becsum, nh_tot * ucell->nat * PARAM.inp.nspin, "ElecState<PW>::becsum");
283-
setmem_var_op()(this->ctx, becsum, 0, nh_tot * ucell->nat * PARAM.inp.nspin);
283+
setmem_var_op()(becsum, 0, nh_tot * ucell->nat * PARAM.inp.nspin);
284284

285285
for (int ik = 0; ik < psi.get_nk(); ++ik)
286286
{

source/module_elecstate/elecstate_pw_cal_tau.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
99
ModuleBase::TITLE("ElecStatePW", "cal_tau");
1010
for(int is=0; is<PARAM.inp.nspin; is++)
1111
{
12-
setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx);
12+
setmem_var_op()(this->kin_r[is], 0, this->charge->nrxx);
1313
}
1414

1515
for (int ik = 0; ik < psi.get_nk(); ++ik)
@@ -31,7 +31,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
3131
// kinetic energy density
3232
for (int j = 0; j < 3; j++)
3333
{
34-
setmem_complex_op()(this->ctx, this->wfcr, 0, this->charge->nrxx);
34+
setmem_complex_op()(this->wfcr, 0, this->charge->nrxx);
3535

3636
meta_op()(this->ctx,
3737
ik,

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
1616
const int nspin = PARAM.inp.nspin;
1717
for (int is = 0; is < nspin; is++)
1818
{
19-
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
19+
setmem_var_op()(this->rho[is], 0, this->charge->nrxx);
2020
}
2121

2222
if (GlobalV::MY_STOGROUP == 0)

0 commit comments

Comments
 (0)