Skip to content

Commit 97aa0c8

Browse files
committed
Remove ctx parameters in sync_memory_op
1 parent d485c50 commit 97aa0c8

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

64 files changed

+355
-448
lines changed

python/pyabacus/src/hsolver/py_diago_cg.hpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,6 @@ class PyDiagoCG
153153
const int nrow = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
154154
const int nbands = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
155155
syncmem_z2z_h2h_op()(
156-
this->ctx,
157-
this->ctx,
158156
spsi_out.data<std::complex<double>>(),
159157
psi_in.data<std::complex<double>>(),
160158
static_cast<size_t>(nrow * nbands)

python/pyabacus/src/hsolver/py_diago_david.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ class PyDiagoDavid
135135
const int nrow,
136136
const int nbands
137137
) {
138-
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
138+
syncmem_op()(spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
139139
};
140140

141141
obj = std::make_unique<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>>(

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
7575

7676
auto* swap = new T[notconv * nbase_x];
7777
auto* target = new T[notconv * nbase_x];
78-
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
78+
syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x);
7979
if (base_device::get_current_precision(swap) == "single")
8080
{
8181
MPI_Reduce(swap,
@@ -97,8 +97,8 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
9797
diag_comm);
9898
}
9999

100-
syncmem_complex_op()(cpu_ctx, cpu_ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
101-
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);
100+
syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x);
101+
syncmem_complex_op()(swap, scc + nbase * nbase_x, notconv * nbase_x);
102102

103103
if (base_device::get_current_precision(swap) == "single")
104104
{
@@ -121,7 +121,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
121121
diag_comm);
122122
}
123123

124-
syncmem_complex_op()(cpu_ctx, cpu_ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
124+
syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x);
125125
delete[] swap;
126126
delete[] target;
127127
}

source/module_base/kernels/test/math_op_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -310,9 +310,9 @@ TEST_F(TestModuleBaseMathMultiDevice, cal_ylm_real_op_gpu)
310310
resmem_var_op()(d_p, p.size());
311311
resmem_var_op()(d_ylm, ylm.size());
312312

313-
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_g, g.data(), g.size());
314-
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_p, p.data(), p.size());
315-
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_ylm, ylm.data(), ylm.size());
313+
syncmem_var_h2d_op()(d_g, g.data(), g.size());
314+
syncmem_var_h2d_op()(d_p, p.data(), p.size());
315+
syncmem_var_h2d_op()(d_ylm, ylm.data(), ylm.size());
316316

317317
ModuleBase::cal_ylm_real_op<double, base_device::DEVICE_GPU>()(gpu_ctx,
318318
ng,
@@ -326,7 +326,7 @@ TEST_F(TestModuleBaseMathMultiDevice, cal_ylm_real_op_gpu)
326326
d_p,
327327
d_ylm);
328328

329-
syncmem_var_d2h_op()(cpu_ctx, gpu_ctx, ylm.data(), d_ylm, ylm.size());
329+
syncmem_var_d2h_op()(ylm.data(), d_ylm, ylm.size());
330330

331331
for (int ii = 0; ii < ylm.size(); ii++) {
332332
EXPECT_LT(fabs(ylm[ii] - expected_ylm[ii]), 6e-5);

source/module_base/math_chebyshev.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ REAL Chebyshev<REAL, Device>::ddot_real(const std::complex<REAL>* psi_L,
131131
REAL* dot_device = nullptr;
132132
resmem_var_op()(dot_device, 1);
133133
container::kernels::blas_dot<REAL, ct_Device>()(dim2, pL, 1, pR, 1, dot_device);
134-
syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result, dot_device, 1);
134+
syncmem_var_d2h_op()(&result, dot_device, 1);
135135
delmem_var_op()(this->ctx, dot_device);
136136
}
137137
else
@@ -146,7 +146,7 @@ REAL Chebyshev<REAL, Device>::ddot_real(const std::complex<REAL>* psi_L,
146146
int dim2 = 2 * N;
147147
container::kernels::blas_dot<REAL, ct_Device>()(dim2, pL, 1, pR, 1, dot_device);
148148
REAL result_temp = 0;
149-
syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result_temp, dot_device, 1);
149+
syncmem_var_d2h_op()(&result_temp, dot_device, 1);
150150
result += result_temp;
151151
pL += 2 * LDA;
152152
pR += 2 * LDA;
@@ -211,7 +211,7 @@ void Chebyshev<REAL, Device>::calcoef_real(std::function<REAL(REAL)> fun)
211211

212212
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
213213
{
214-
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, coef_real, coefr_cpu, norder);
214+
syncmem_var_h2d_op()(coef_real, coefr_cpu, norder);
215215
}
216216

217217
getcoef_real = true;
@@ -301,7 +301,7 @@ void Chebyshev<REAL, Device>::calcoef_complex(std::function<std::complex<REAL>(s
301301
}
302302
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
303303
{
304-
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder);
304+
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
305305
}
306306

307307
getcoef_complex = true;
@@ -392,7 +392,7 @@ void Chebyshev<REAL, Device>::calcoef_pair(std::function<REAL(REAL)> fun1, std::
392392

393393
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
394394
{
395-
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder);
395+
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
396396
}
397397

398398
getcoef_complex = true;
@@ -431,7 +431,7 @@ void Chebyshev<REAL, Device>::calfinalvec_real(
431431
resmem_complex_op()(arrayn, ndmxt);
432432
resmem_complex_op()(arrayn_1, ndmxt);
433433

434-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
434+
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
435435
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt);
436436

437437
funA(arrayn_1, arrayn, m);
@@ -500,7 +500,7 @@ void Chebyshev<REAL, Device>::calfinalvec_complex(
500500
resmem_complex_op()(arrayn, ndmxt);
501501
resmem_complex_op()(arrayn_1, ndmxt);
502502

503-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
503+
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
504504

505505
funA(arrayn_1, arrayn, m);
506506

@@ -553,7 +553,7 @@ void Chebyshev<REAL, Device>::calpolyvec_complex(
553553
std::complex<REAL>*tmpin = wavein, *tmpout = arrayn_1;
554554
for (int i = 0; i < m; ++i)
555555
{
556-
memcpy_complex_op()(this->ctx, this->ctx, tmpout, tmpin, N);
556+
memcpy_complex_op()(tmpout, tmpin, N);
557557
// ModuleBase::GlobalFunc::DCOPY(tmpin, tmpout, N);
558558
tmpin += LDA;
559559
tmpout += LDA;
@@ -599,7 +599,7 @@ void Chebyshev<REAL, Device>::tracepolyA(
599599
resmem_complex_op()(arrayn, ndmxt);
600600
resmem_complex_op()(arrayn_1, ndmxt);
601601

602-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
602+
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
603603
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt);
604604

605605
funA(arrayn_1, arrayn, m);
@@ -673,7 +673,7 @@ bool Chebyshev<REAL, Device>::checkconverge(
673673
resmem_complex_op()(arrayn, LDA);
674674
resmem_complex_op()(arrayn_1, LDA);
675675

676-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, N);
676+
memcpy_complex_op()(arrayn_1, wavein, N);
677677
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N);
678678

679679
if (tmin == tmax)

source/module_base/module_device/cuda/memory_op.cu

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -87,8 +87,6 @@ void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE* arr,
8787

8888
template <typename FPTYPE>
8989
void synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>::operator()(
90-
const base_device::DEVICE_CPU* dev_out,
91-
const base_device::DEVICE_GPU* dev_in,
9290
FPTYPE* arr_out,
9391
const FPTYPE* arr_in,
9492
const size_t size)
@@ -98,8 +96,6 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_
9896

9997
template <typename FPTYPE>
10098
void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_CPU>::operator()(
101-
const base_device::DEVICE_GPU* dev_out,
102-
const base_device::DEVICE_CPU* dev_in,
10399
FPTYPE* arr_out,
104100
const FPTYPE* arr_in,
105101
const size_t size)
@@ -109,8 +105,6 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_
109105

110106
template <typename FPTYPE>
111107
void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_GPU>::operator()(
112-
const base_device::DEVICE_GPU* dev_out,
113-
const base_device::DEVICE_GPU* dev_in,
114108
FPTYPE* arr_out,
115109
const FPTYPE* arr_in,
116110
const size_t size)
@@ -150,9 +144,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_devic
150144
// No need to cast the memory if the data types are the same.
151145
if (std::is_same<FPTYPE_out, FPTYPE_in>::value)
152146
{
153-
synchronize_memory_op<FPTYPE_out, base_device::DEVICE_GPU, base_device::DEVICE_CPU>()(dev_out,
154-
dev_in,
155-
arr_out,
147+
synchronize_memory_op<FPTYPE_out, base_device::DEVICE_GPU, base_device::DEVICE_CPU>()(arr_out,
156148
reinterpret_cast<const FPTYPE_out*>(arr_in),
157149
size);
158150
return;
@@ -178,9 +170,7 @@ struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_CPU, base_devic
178170
// No need to cast the memory if the data types are the same.
179171
if (std::is_same<FPTYPE_out, FPTYPE_in>::value)
180172
{
181-
synchronize_memory_op<FPTYPE_out, base_device::DEVICE_CPU, base_device::DEVICE_GPU>()(dev_out,
182-
dev_in,
183-
arr_out,
173+
synchronize_memory_op<FPTYPE_out, base_device::DEVICE_CPU, base_device::DEVICE_GPU>()(arr_out,
184174
reinterpret_cast<const FPTYPE_out*>(arr_in),
185175
size);
186176
return;

source/module_base/module_device/memory_op.cpp

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,7 @@ struct set_memory_op<FPTYPE, base_device::DEVICE_CPU>
5858
template <typename FPTYPE>
5959
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_CPU>
6060
{
61-
void operator()(const base_device::DEVICE_CPU* dev_out,
62-
const base_device::DEVICE_CPU* dev_in,
63-
FPTYPE* arr_out,
61+
void operator()(FPTYPE* arr_out,
6462
const FPTYPE* arr_in,
6563
const size_t size)
6664
{
@@ -174,9 +172,7 @@ struct set_memory_op<FPTYPE, base_device::DEVICE_GPU>
174172
template <typename FPTYPE>
175173
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_GPU>
176174
{
177-
void operator()(const base_device::DEVICE_GPU* dev_out,
178-
const base_device::DEVICE_GPU* dev_in,
179-
FPTYPE* arr_out,
175+
void operator()(FPTYPE* arr_out,
180176
const FPTYPE* arr_in,
181177
const size_t size)
182178
{
@@ -186,9 +182,7 @@ struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVIC
186182
template <typename FPTYPE>
187183
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_CPU>
188184
{
189-
void operator()(const base_device::DEVICE_GPU* dev_out,
190-
const base_device::DEVICE_CPU* dev_in,
191-
FPTYPE* arr_out,
185+
void operator()(FPTYPE* arr_out,
192186
const FPTYPE* arr_in,
193187
const size_t size)
194188
{
@@ -198,9 +192,7 @@ struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVIC
198192
template <typename FPTYPE>
199193
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>
200194
{
201-
void operator()(const base_device::DEVICE_CPU* dev_out,
202-
const base_device::DEVICE_GPU* dev_in,
203-
FPTYPE* arr_out,
195+
void operator()(FPTYPE* arr_out,
204196
const FPTYPE* arr_in,
205197
const size_t size)
206198
{

source/module_base/module_device/memory_op.h

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -46,16 +46,12 @@ struct synchronize_memory_op
4646
/// @brief memcpy for multi-device
4747
///
4848
/// Input Parameters
49-
/// \param dev_out : the type of computing device of arr_out
50-
/// \param dev_in : the type of computing device of arr_in
5149
/// \param arr_in : input array
5250
/// \param size : array size
5351
///
5452
/// Output Parameters
5553
/// \param arr_out : output array initialized by the input array
56-
void operator()(const Device_out* dev_out,
57-
const Device_in* dev_in,
58-
FPTYPE* arr_out,
54+
void operator()(FPTYPE* arr_out,
5955
const FPTYPE* arr_in,
6056
const size_t size);
6157
};
@@ -125,27 +121,21 @@ struct set_memory_op<FPTYPE, base_device::DEVICE_GPU>
125121
template <typename FPTYPE>
126122
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>
127123
{
128-
void operator()(const base_device::DEVICE_CPU* dev_out,
129-
const base_device::DEVICE_GPU* dev_in,
130-
FPTYPE* arr_out,
124+
void operator()(FPTYPE* arr_out,
131125
const FPTYPE* arr_in,
132126
const size_t size);
133127
};
134128
template <typename FPTYPE>
135129
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_CPU>
136130
{
137-
void operator()(const base_device::DEVICE_GPU* dev_out,
138-
const base_device::DEVICE_CPU* dev_in,
139-
FPTYPE* arr_out,
131+
void operator()(FPTYPE* arr_out,
140132
const FPTYPE* arr_in,
141133
const size_t size);
142134
};
143135
template <typename FPTYPE>
144136
struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_GPU>
145137
{
146-
void operator()(const base_device::DEVICE_GPU* dev_out,
147-
const base_device::DEVICE_GPU* dev_in,
148-
FPTYPE* arr_out,
138+
void operator()(FPTYPE* arr_out,
149139
const FPTYPE* arr_in,
150140
const size_t size);
151141
};

source/module_base/module_device/rocm/memory_op.hip.cu

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,7 @@ void resize_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE*& arr,
5151
}
5252

5353
template <typename FPTYPE>
54-
void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
55-
FPTYPE* arr,
54+
void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE* arr,
5655
const int var,
5756
const size_t size)
5857
{
@@ -61,8 +60,6 @@ void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_devic
6160

6261
template <typename FPTYPE>
6362
void synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>::operator()(
64-
const base_device::DEVICE_CPU* dev_out,
65-
const base_device::DEVICE_GPU* dev_in,
6663
FPTYPE* arr_out,
6764
const FPTYPE* arr_in,
6865
const size_t size)
@@ -72,8 +69,6 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_
7269

7370
template <typename FPTYPE>
7471
void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_CPU>::operator()(
75-
const base_device::DEVICE_GPU* dev_out,
76-
const base_device::DEVICE_CPU* dev_in,
7772
FPTYPE* arr_out,
7873
const FPTYPE* arr_in,
7974
const size_t size)
@@ -83,8 +78,6 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_
8378

8479
template <typename FPTYPE>
8580
void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_GPU>::operator()(
86-
const base_device::DEVICE_GPU* dev_out,
87-
const base_device::DEVICE_GPU* dev_in,
8881
FPTYPE* arr_out,
8982
const FPTYPE* arr_in,
9083
const size_t size)

0 commit comments

Comments
 (0)