Skip to content

Commit 9bfb2ee

Browse files
Critsium-xypre-commit-ci-lite[bot]
authored andcommitted
[Refactor] Remove all ctx parameters in memory operators (deepmodeling#5862)
* Remove all ctx parameters in resize_memory_op * Small bug fix * [pre-commit.ci lite] apply automatic fixes * Remove all ctx parameters in set_memory_op * Remove ctx parameters in sync_memory_op * Remove ctx parameters in cast_memory_op * Remove ctx parameter in delete_memory_op --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 08d9d57 commit 9bfb2ee

File tree

80 files changed

+1382
-1534
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+1382
-1534
lines changed

python/pyabacus/src/hsolver/py_diago_cg.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,6 @@ class PyDiagoCG
153153
const int nrow = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
154154
const int nbands = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
155155
syncmem_z2z_h2h_op()(
156-
this->ctx,
157-
this->ctx,
158156
spsi_out.data<std::complex<double>>(),
159157
psi_in.data<std::complex<double>>(),
160158
static_cast<size_t>(nrow * nbands)

python/pyabacus/src/hsolver/py_diago_david.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class PyDiagoDavid
135135
const int nrow,
136136
const int nbands
137137
) {
138-
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
138+
syncmem_op()(spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
139139
};
140140

141141
obj = std::make_unique<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>>(

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
7575

7676
auto* swap = new T[notconv * nbase_x];
7777
auto* target = new T[notconv * nbase_x];
78-
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
78+
syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x);
7979
if (base_device::get_current_precision(swap) == "single")
8080
{
8181
MPI_Reduce(swap,
@@ -97,8 +97,8 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
9797
diag_comm);
9898
}
9999

100-
syncmem_complex_op()(cpu_ctx, cpu_ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
101-
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);
100+
syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x);
101+
syncmem_complex_op()(swap, scc + nbase * nbase_x, notconv * nbase_x);
102102

103103
if (base_device::get_current_precision(swap) == "single")
104104
{
@@ -121,7 +121,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
121121
diag_comm);
122122
}
123123

124-
syncmem_complex_op()(cpu_ctx, cpu_ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
124+
syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x);
125125
delete[] swap;
126126
delete[] target;
127127
}

source/module_base/kernels/test/math_op_test.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,13 @@ TEST_F(TestModuleBaseMathMultiDevice, cal_ylm_real_op_gpu)
306306
std::vector<double> ylm(expected_ylm.size(), 0.0);
307307
double * d_ylm = nullptr, * d_g = nullptr, * d_p = nullptr;
308308

309-
resmem_var_op()(gpu_ctx, d_g, g.size());
310-
resmem_var_op()(gpu_ctx, d_p, p.size());
311-
resmem_var_op()(gpu_ctx, d_ylm, ylm.size());
309+
resmem_var_op()(d_g, g.size());
310+
resmem_var_op()(d_p, p.size());
311+
resmem_var_op()(d_ylm, ylm.size());
312312

313-
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_g, g.data(), g.size());
314-
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_p, p.data(), p.size());
315-
syncmem_var_h2d_op()(gpu_ctx, cpu_ctx, d_ylm, ylm.data(), ylm.size());
313+
syncmem_var_h2d_op()(d_g, g.data(), g.size());
314+
syncmem_var_h2d_op()(d_p, p.data(), p.size());
315+
syncmem_var_h2d_op()(d_ylm, ylm.data(), ylm.size());
316316

317317
ModuleBase::cal_ylm_real_op<double, base_device::DEVICE_GPU>()(gpu_ctx,
318318
ng,
@@ -326,15 +326,15 @@ TEST_F(TestModuleBaseMathMultiDevice, cal_ylm_real_op_gpu)
326326
d_p,
327327
d_ylm);
328328

329-
syncmem_var_d2h_op()(cpu_ctx, gpu_ctx, ylm.data(), d_ylm, ylm.size());
329+
syncmem_var_d2h_op()(ylm.data(), d_ylm, ylm.size());
330330

331331
for (int ii = 0; ii < ylm.size(); ii++) {
332332
EXPECT_LT(fabs(ylm[ii] - expected_ylm[ii]), 6e-5);
333333
}
334334

335-
delmem_var_op()(gpu_ctx, d_g);
336-
delmem_var_op()(gpu_ctx, d_p);
337-
delmem_var_op()(gpu_ctx, d_ylm);
335+
delmem_var_op()(d_g);
336+
delmem_var_op()(d_p);
337+
delmem_var_op()(d_ylm);
338338
}
339339

340340
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM

source/module_base/math_chebyshev.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ Chebyshev<REAL, Device>::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord
6363
coefc_cpu = new std::complex<REAL>[norder];
6464
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
6565
{
66-
resmem_var_op()(this->ctx, this->coef_real, norder);
67-
resmem_complex_op()(this->ctx, this->coef_complex, norder);
66+
resmem_var_op()(this->coef_real, norder);
67+
resmem_complex_op()(this->coef_complex, norder);
6868
}
6969
else
7070
{
@@ -84,8 +84,8 @@ Chebyshev<REAL, Device>::~Chebyshev()
8484
delete[] polytrace;
8585
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
8686
{
87-
delmem_var_op()(this->ctx, this->coef_real);
88-
delmem_complex_op()(this->ctx, this->coef_complex);
87+
delmem_var_op()(this->coef_real);
88+
delmem_complex_op()(this->coef_complex);
8989
}
9090
else
9191
{
@@ -129,29 +129,29 @@ REAL Chebyshev<REAL, Device>::ddot_real(const std::complex<REAL>* psi_L,
129129
pL = (REAL*)psi_L;
130130
pR = (REAL*)psi_R;
131131
REAL* dot_device = nullptr;
132-
resmem_var_op()(this->ctx, dot_device, 1);
132+
resmem_var_op()(dot_device, 1);
133133
container::kernels::blas_dot<REAL, ct_Device>()(dim2, pL, 1, pR, 1, dot_device);
134-
syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result, dot_device, 1);
135-
delmem_var_op()(this->ctx, dot_device);
134+
syncmem_var_d2h_op()(&result, dot_device, 1);
135+
delmem_var_op()(dot_device);
136136
}
137137
else
138138
{
139139
REAL *pL, *pR;
140140
pL = (REAL*)psi_L;
141141
pR = (REAL*)psi_R;
142142
REAL* dot_device = nullptr;
143-
resmem_var_op()(this->ctx, dot_device, 1);
143+
resmem_var_op()(dot_device, 1);
144144
for (int i = 0; i < m; ++i)
145145
{
146146
int dim2 = 2 * N;
147147
container::kernels::blas_dot<REAL, ct_Device>()(dim2, pL, 1, pR, 1, dot_device);
148148
REAL result_temp = 0;
149-
syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result_temp, dot_device, 1);
149+
syncmem_var_d2h_op()(&result_temp, dot_device, 1);
150150
result += result_temp;
151151
pL += 2 * LDA;
152152
pR += 2 * LDA;
153153
}
154-
delmem_var_op()(this->ctx, dot_device);
154+
delmem_var_op()(dot_device);
155155
}
156156
return result;
157157
}
@@ -211,7 +211,7 @@ void Chebyshev<REAL, Device>::calcoef_real(std::function<REAL(REAL)> fun)
211211

212212
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
213213
{
214-
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, coef_real, coefr_cpu, norder);
214+
syncmem_var_h2d_op()(coef_real, coefr_cpu, norder);
215215
}
216216

217217
getcoef_real = true;
@@ -301,7 +301,7 @@ void Chebyshev<REAL, Device>::calcoef_complex(std::function<std::complex<REAL>(s
301301
}
302302
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
303303
{
304-
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder);
304+
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
305305
}
306306

307307
getcoef_complex = true;
@@ -392,7 +392,7 @@ void Chebyshev<REAL, Device>::calcoef_pair(std::function<REAL(REAL)> fun1, std::
392392

393393
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
394394
{
395-
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder);
395+
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
396396
}
397397

398398
getcoef_complex = true;
@@ -427,17 +427,17 @@ void Chebyshev<REAL, Device>::calfinalvec_real(
427427
ndmxt = LDA * m;
428428
}
429429

430-
resmem_complex_op()(this->ctx, arraynp1, ndmxt);
431-
resmem_complex_op()(this->ctx, arrayn, ndmxt);
432-
resmem_complex_op()(this->ctx, arrayn_1, ndmxt);
430+
resmem_complex_op()(arraynp1, ndmxt);
431+
resmem_complex_op()(arrayn, ndmxt);
432+
resmem_complex_op()(arrayn_1, ndmxt);
433433

434-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
434+
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
435435
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt);
436436

437437
funA(arrayn_1, arrayn, m);
438438

439439
// 0- & 1-st order
440-
setmem_complex_op()(this->ctx, waveout, 0, ndmxt);
440+
setmem_complex_op()(waveout, 0, ndmxt);
441441
std::complex<REAL> coef0 = std::complex<REAL>(coefr_cpu[0], 0);
442442
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coef0, arrayn_1, 1, waveout, 1);
443443
std::complex<REAL> coef1 = std::complex<REAL>(coefr_cpu[1], 0);
@@ -462,9 +462,9 @@ void Chebyshev<REAL, Device>::calfinalvec_real(
462462
arrayn = arraynp1;
463463
arraynp1 = tem;
464464
}
465-
delmem_complex_op()(this->ctx, arraynp1);
466-
delmem_complex_op()(this->ctx, arrayn);
467-
delmem_complex_op()(this->ctx, arrayn_1);
465+
delmem_complex_op()(arraynp1);
466+
delmem_complex_op()(arrayn);
467+
delmem_complex_op()(arrayn_1);
468468
return;
469469
}
470470

@@ -496,16 +496,16 @@ void Chebyshev<REAL, Device>::calfinalvec_complex(
496496
ndmxt = LDA * m;
497497
}
498498

499-
resmem_complex_op()(this->ctx, arraynp1, ndmxt);
500-
resmem_complex_op()(this->ctx, arrayn, ndmxt);
501-
resmem_complex_op()(this->ctx, arrayn_1, ndmxt);
499+
resmem_complex_op()(arraynp1, ndmxt);
500+
resmem_complex_op()(arrayn, ndmxt);
501+
resmem_complex_op()(arrayn_1, ndmxt);
502502

503-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
503+
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
504504

505505
funA(arrayn_1, arrayn, m);
506506

507507
// 0- & 1-st order
508-
setmem_complex_op()(this->ctx, waveout, 0, ndmxt);
508+
setmem_complex_op()(waveout, 0, ndmxt);
509509
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coefc_cpu[0], arrayn_1, 1, waveout, 1);
510510
container::kernels::blas_axpy<std::complex<REAL>, ct_Device>()(ndmxt, &coefc_cpu[1], arrayn, 1, waveout, 1);
511511
// for (int i = 0; i < ndmxt; ++i)
@@ -527,9 +527,9 @@ void Chebyshev<REAL, Device>::calfinalvec_complex(
527527
arrayn = arraynp1;
528528
arraynp1 = tem;
529529
}
530-
delmem_complex_op()(this->ctx, arraynp1);
531-
delmem_complex_op()(this->ctx, arrayn);
532-
delmem_complex_op()(this->ctx, arrayn_1);
530+
delmem_complex_op()(arraynp1);
531+
delmem_complex_op()(arrayn);
532+
delmem_complex_op()(arrayn_1);
533533
return;
534534
}
535535

@@ -553,7 +553,7 @@ void Chebyshev<REAL, Device>::calpolyvec_complex(
553553
std::complex<REAL>*tmpin = wavein, *tmpout = arrayn_1;
554554
for (int i = 0; i < m; ++i)
555555
{
556-
memcpy_complex_op()(this->ctx, this->ctx, tmpout, tmpin, N);
556+
memcpy_complex_op()(tmpout, tmpin, N);
557557
// ModuleBase::GlobalFunc::DCOPY(tmpin, tmpout, N);
558558
tmpin += LDA;
559559
tmpout += LDA;
@@ -595,11 +595,11 @@ void Chebyshev<REAL, Device>::tracepolyA(
595595
ndmxt = LDA * m;
596596
}
597597

598-
resmem_complex_op()(this->ctx, arraynp1, ndmxt);
599-
resmem_complex_op()(this->ctx, arrayn, ndmxt);
600-
resmem_complex_op()(this->ctx, arrayn_1, ndmxt);
598+
resmem_complex_op()(arraynp1, ndmxt);
599+
resmem_complex_op()(arrayn, ndmxt);
600+
resmem_complex_op()(arrayn_1, ndmxt);
601601

602-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt);
602+
memcpy_complex_op()(arrayn_1, wavein, ndmxt);
603603
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt);
604604

605605
funA(arrayn_1, arrayn, m);
@@ -618,9 +618,9 @@ void Chebyshev<REAL, Device>::tracepolyA(
618618
arraynp1 = tem;
619619
}
620620

621-
delmem_complex_op()(this->ctx, arraynp1);
622-
delmem_complex_op()(this->ctx, arrayn);
623-
delmem_complex_op()(this->ctx, arrayn_1);
621+
delmem_complex_op()(arraynp1);
622+
delmem_complex_op()(arrayn);
623+
delmem_complex_op()(arrayn_1);
624624
return;
625625
}
626626

@@ -669,11 +669,11 @@ bool Chebyshev<REAL, Device>::checkconverge(
669669
std::complex<REAL>* arrayn = nullptr;
670670
std::complex<REAL>* arrayn_1 = nullptr;
671671

672-
resmem_complex_op()(this->ctx, arraynp1, LDA);
673-
resmem_complex_op()(this->ctx, arrayn, LDA);
674-
resmem_complex_op()(this->ctx, arrayn_1, LDA);
672+
resmem_complex_op()(arraynp1, LDA);
673+
resmem_complex_op()(arrayn, LDA);
674+
resmem_complex_op()(arrayn_1, LDA);
675675

676-
memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, N);
676+
memcpy_complex_op()(arrayn_1, wavein, N);
677677
// ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N);
678678

679679
if (tmin == tmax)
@@ -754,9 +754,9 @@ bool Chebyshev<REAL, Device>::checkconverge(
754754
arraynp1 = tem;
755755
}
756756

757-
delmem_complex_op()(this->ctx, arraynp1);
758-
delmem_complex_op()(this->ctx, arrayn);
759-
delmem_complex_op()(this->ctx, arrayn_1);
757+
delmem_complex_op()(arraynp1);
758+
delmem_complex_op()(arrayn);
759+
delmem_complex_op()(arrayn_1);
760760
return converge;
761761
}
762762

source/module_base/math_ylmreal.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ void YlmReal::Ylm_Real(Device * ctx, const int lmax2, const int ng, const FPTYPE
327327
ModuleBase::WARNING_QUIT("YLM_REAL","l>30 or l<0");
328328
}
329329
FPTYPE * p = nullptr, * phi = nullptr, * cost = nullptr;
330-
resmem_var_op()(ctx, p, (lmax + 1) * (lmax + 1) * ng, "YlmReal::Ylm_Real");
330+
resmem_var_op()(p, (lmax + 1) * (lmax + 1) * ng, "YlmReal::Ylm_Real");
331331

332332
cal_ylm_real_op()(
333333
ctx,
@@ -342,9 +342,9 @@ void YlmReal::Ylm_Real(Device * ctx, const int lmax2, const int ng, const FPTYPE
342342
p,
343343
ylm);
344344

345-
delmem_var_op()(ctx, p);
346-
delmem_var_op()(ctx, phi);
347-
delmem_var_op()(ctx, cost);
345+
delmem_var_op()(p);
346+
delmem_var_op()(phi);
347+
delmem_var_op()(cost);
348348
} // end subroutine ylmr2
349349

350350
//==========================================================

0 commit comments

Comments
 (0)