Skip to content

Commit 248a9e2

Browse files
committed
Remove ctx parameters in cast_memory_op
1 parent 97aa0c8 commit 248a9e2

File tree

16 files changed

+54 -102 lines changed

16 files changed

+54 -102 lines changed

source/module_base/module_device/cuda/memory_op.cu

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -115,9 +115,7 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_
115115
template <typename FPTYPE_out, typename FPTYPE_in>
116116
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_GPU>
117117
{
118-
void operator()(const base_device::DEVICE_GPU* dev_out,
119-
const base_device::DEVICE_GPU* dev_in,
120-
FPTYPE_out* arr_out,
118+
void operator()(FPTYPE_out* arr_out,
121119
const FPTYPE_in* arr_in,
122120
const size_t size)
123121
{
@@ -134,9 +132,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
134132

135133
template <typename FPTYPE_out, typename FPTYPE_in>
136134
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_CPU> {
137-
void operator()(const base_device::DEVICE_GPU* dev_out,
138-
const base_device::DEVICE_CPU* dev_in,
139-
FPTYPE_out* arr_out,
135+
void operator()(FPTYPE_out* arr_out,
140136
const FPTYPE_in* arr_in,
141137
const size_t size) {
142138

@@ -161,9 +157,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
161157

162158
template <typename FPTYPE_out, typename FPTYPE_in>
163159
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_GPU> {
164-
void operator()(const base_device::DEVICE_CPU* dev_out,
165-
const base_device::DEVICE_GPU* dev_in,
166-
FPTYPE_out* arr_out,
160+
void operator()(FPTYPE_out* arr_out,
167161
const FPTYPE_in* arr_in,
168162
const size_t size) {
169163
if (size == 0) {return;}

source/module_base/module_device/memory_op.cpp

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -73,9 +73,7 @@ struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVIC
7373
template <typename FPTYPE_out, typename FPTYPE_in>
7474
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_CPU>
7575
{
76-
void operator()(const base_device::DEVICE_CPU* dev_out,
77-
const base_device::DEVICE_CPU* dev_in,
78-
FPTYPE_out* arr_out,
76+
void operator()(FPTYPE_out* arr_out,
7977
const FPTYPE_in* arr_in,
8078
const size_t size)
8179
{
@@ -202,9 +200,7 @@ struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVIC
202200
template <typename FPTYPE_out, typename FPTYPE_in>
203201
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_GPU>
204202
{
205-
void operator()(const base_device::DEVICE_GPU* dev_out,
206-
const base_device::DEVICE_GPU* dev_in,
207-
FPTYPE_out* arr_out,
203+
void operator()(FPTYPE_out* arr_out,
208204
const FPTYPE_in* arr_in,
209205
const size_t size)
210206
{
@@ -214,9 +210,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
214210
template <typename FPTYPE_out, typename FPTYPE_in>
215211
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_CPU>
216212
{
217-
void operator()(const base_device::DEVICE_GPU* dev_out,
218-
const base_device::DEVICE_CPU* dev_in,
219-
FPTYPE_out* arr_out,
213+
void operator()(FPTYPE_out* arr_out,
220214
const FPTYPE_in* arr_in,
221215
const size_t size)
222216
{
@@ -226,9 +220,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
226220
template <typename FPTYPE_out, typename FPTYPE_in>
227221
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_GPU>
228222
{
229-
void operator()(const base_device::DEVICE_CPU* dev_out,
230-
const base_device::DEVICE_GPU* dev_in,
231-
FPTYPE_out* arr_out,
223+
void operator()(FPTYPE_out* arr_out,
232224
const FPTYPE_in* arr_in,
233225
const size_t size)
234226
{

source/module_base/module_device/memory_op.h

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,7 @@ struct cast_memory_op
6969
///
7070
/// Output Parameters
7171
/// \param arr_out : output array initialized by the input array
72-
void operator()(const Device_out* dev_out,
73-
const Device_in* dev_in,
74-
FPTYPE_out* arr_out,
72+
void operator()(FPTYPE_out* arr_out,
7573
const FPTYPE_in* arr_in,
7674
const size_t size);
7775
};

source/module_base/module_device/rocm/memory_op.hip.cu

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -87,9 +87,7 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_
8787

8888
template <typename FPTYPE_out, typename FPTYPE_in>
8989
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_GPU> {
90-
void operator()(const base_device::DEVICE_GPU* dev_out,
91-
const base_device::DEVICE_GPU* dev_in,
92-
FPTYPE_out* arr_out,
90+
void operator()(FPTYPE_out* arr_out,
9391
const FPTYPE_in* arr_in,
9492
const size_t size) {
9593

@@ -102,9 +100,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
102100

103101
template <typename FPTYPE_out, typename FPTYPE_in>
104102
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_CPU> {
105-
void operator()(const base_device::DEVICE_GPU* dev_out,
106-
const base_device::DEVICE_CPU* dev_in,
107-
FPTYPE_out* arr_out,
103+
void operator()(FPTYPE_out* arr_out,
108104
const FPTYPE_in* arr_in,
109105
const size_t size) {
110106

@@ -131,9 +127,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
131127

132128
template <typename FPTYPE_out, typename FPTYPE_in>
133129
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_device::DEVICE_GPU> {
134-
void operator()(const base_device::DEVICE_CPU* dev_out,
135-
const base_device::DEVICE_GPU* dev_in,
136-
FPTYPE_out* arr_out,
130+
void operator()(FPTYPE_out* arr_out,
137131
const FPTYPE_in* arr_in,
138132
const size_t size) {
139133

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ void PW_Basis_K:: initparameters(
100100
if (this->device == "gpu") {
101101
if (this->precision == "single") {
102102
resmem_sd_op()(this->s_kvec_c, this->nks * 3);
103-
castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
103+
castmem_d2s_h2d_op()(this->s_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
104104
}
105105
resmem_dd_op()(this->d_kvec_c, this->nks * 3);
106106
syncmem_d2d_h2d_op()(this->d_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
@@ -109,7 +109,7 @@ void PW_Basis_K:: initparameters(
109109
#endif
110110
if (this->precision == "single") {
111111
resmem_sh_op()(this->s_kvec_c, this->nks * 3);
112-
castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
112+
castmem_d2s_h2h_op()(this->s_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
113113
}
114114
this->d_kvec_c = reinterpret_cast<double *>(&this->kvec_c[0][0]);
115115
// There's no need to allocate double pointers while in a CPU environment.
@@ -249,8 +249,8 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h
249249
if (this->precision == "single") {
250250
resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks);
251251
resmem_sd_op()(this->s_gcar, this->npwk_max * this->nks * 3);
252-
castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_gk2, this->gk2, this->npwk_max * this->nks);
253-
castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_gcar, reinterpret_cast<double *>(&this->gcar[0][0]), this->npwk_max * this->nks * 3);
252+
castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks);
253+
castmem_d2s_h2d_op()(this->s_gcar, reinterpret_cast<double *>(&this->gcar[0][0]), this->npwk_max * this->nks * 3);
254254
}
255255
else {
256256
resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks);
@@ -264,8 +264,8 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h
264264
if (this->precision == "single") {
265265
resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2");
266266
resmem_sh_op()(this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar");
267-
castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_gk2, this->gk2, this->npwk_max * this->nks);
268-
castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_gcar, reinterpret_cast<double *>(&this->gcar[0][0]), this->npwk_max * this->nks * 3);
267+
castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks);
268+
castmem_d2s_h2h_op()(this->s_gcar, reinterpret_cast<double *>(&this->gcar[0][0]), this->npwk_max * this->nks * 3);
269269
}
270270
else {
271271
this->d_gcar = reinterpret_cast<double *>(&this->gcar[0][0]);

source/module_elecstate/elecstate_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,10 +142,10 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
142142
{
143143
for (int ii = 0; ii < PARAM.inp.nspin; ii++)
144144
{
145-
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
145+
castmem_var_d2h_op()(this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
146146
if (get_xc_func_type() == 3)
147147
{
148-
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
148+
castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
149149
}
150150
}
151151
}

source/module_elecstate/elecstate_pw_cal_tau.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
5252
}
5353
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
5454
for (int ii = 0; ii < PARAM.inp.nspin; ii++) {
55-
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
55+
castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
5656
}
5757
}
5858
this->parallelK();

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
2828
}
2929
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
3030
for (int ii = 0; ii < nspin; ii++) {
31-
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
31+
castmem_var_d2h_op()(this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
3232
}
3333
}
3434
this->parallelK();

source/module_elecstate/potentials/potential_new.cpp

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -181,14 +181,10 @@ void Potential::update_from_charge(const Charge*const chg, const UnitCell*const
181181

182182
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
183183
if (PARAM.inp.precision == "single") {
184-
castmem_d2s_h2d_op()(gpu_ctx,
185-
cpu_ctx,
186-
s_veff_smooth,
184+
castmem_d2s_h2d_op()(s_veff_smooth,
187185
this->veff_smooth.c,
188186
this->veff_smooth.nr * this->veff_smooth.nc);
189-
castmem_d2s_h2d_op()(gpu_ctx,
190-
cpu_ctx,
191-
s_vofk_smooth,
187+
castmem_d2s_h2d_op()(s_vofk_smooth,
192188
this->vofk_smooth.c,
193189
this->vofk_smooth.nr * this->vofk_smooth.nc);
194190
}
@@ -203,14 +199,10 @@ void Potential::update_from_charge(const Charge*const chg, const UnitCell*const
203199
}
204200
else {
205201
if (PARAM.inp.precision == "single") {
206-
castmem_d2s_h2h_op()(cpu_ctx,
207-
cpu_ctx,
208-
s_veff_smooth,
202+
castmem_d2s_h2h_op()(s_veff_smooth,
209203
this->veff_smooth.c,
210204
this->veff_smooth.nr * this->veff_smooth.nc);
211-
castmem_d2s_h2h_op()(cpu_ctx,
212-
cpu_ctx,
213-
s_vofk_smooth,
205+
castmem_d2s_h2h_op()(s_vofk_smooth,
214206
this->vofk_smooth.c,
215207
this->vofk_smooth.nr * this->vofk_smooth.nc);
216208
}

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -646,9 +646,7 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep)
646646
// 4) Transfer data from GPU to CPU
647647
if (this->device == base_device::GpuDevice)
648648
{
649-
castmem_2d_d2h_op()(this->psi[0].get_device(),
650-
this->kspw_psi[0].get_device(),
651-
this->psi[0].get_pointer() - this->psi[0].get_psi_bias(),
649+
castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(),
652650
this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(),
653651
this->psi[0].size());
654652
}

0 commit comments

Comments
 (0)