Skip to content

Commit 937eb35

Browse files
committed
Remove ctx in dot_real_op
1 parent 105fd19 commit 937eb35

File tree

9 files changed

+25
-34
lines changed

9 files changed

+25
-34
lines changed

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,7 @@ void calc_grad_with_block_op<T, base_device::DEVICE_GPU>::operator()(const Real*
451451
}
452452

453453
template <>
454-
double dot_real_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
455-
const int& dim,
454+
double dot_real_op<double, base_device::DEVICE_GPU>::operator()(const int& dim,
456455
const double* psi_L,
457456
const double* psi_R,
458457
const bool reduce)
@@ -488,17 +487,15 @@ inline FPTYPE dot_complex_wrapper(const base_device::DEVICE_GPU* d,
488487
}
489488

490489
template <>
491-
float dot_real_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
492-
const int& dim,
490+
float dot_real_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int& dim,
493491
const std::complex<float>* psi_L,
494492
const std::complex<float>* psi_R,
495493
const bool reduce)
496494
{
497495
return dot_complex_wrapper(d, dim, psi_L, psi_R, reduce);
498496
}
499497
template <>
500-
double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
501-
const int& dim,
498+
double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int& dim,
502499
const std::complex<double>* psi_L,
503500
const std::complex<double>* psi_R,
504501
const bool reduce)

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ struct calc_grad_with_block_op<T, base_device::DEVICE_CPU>
110110
template <typename FPTYPE>
111111
struct dot_real_op<FPTYPE, base_device::DEVICE_CPU>
112112
{
113-
FPTYPE operator()(const base_device::DEVICE_CPU* d,
114-
const int& dim,
113+
FPTYPE operator()(const int& dim,
115114
const FPTYPE* psi_L,
116115
const FPTYPE* psi_R,
117116
const bool reduce)
@@ -129,8 +128,7 @@ struct dot_real_op<FPTYPE, base_device::DEVICE_CPU>
129128
template <typename FPTYPE>
130129
struct dot_real_op<std::complex<FPTYPE>, base_device::DEVICE_CPU>
131130
{
132-
FPTYPE operator()(const base_device::DEVICE_CPU* d,
133-
const int& dim,
131+
FPTYPE operator()(const int& dim,
134132
const std::complex<FPTYPE>* psi_L,
135133
const std::complex<FPTYPE>* psi_R,
136134
const bool reduce)

source/module_base/kernels/math_kernel_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ template <typename T, typename Device> struct dot_real_op {
9898
///
9999
/// \return
100100
/// FPTYPE : dot product result
101-
Real operator()(const Device *d, const int &dim, const T *psi_L,
101+
Real operator()(const int &dim, const T *psi_L,
102102
const T *psi_R, const bool reduce = true);
103103
};
104104

@@ -347,7 +347,7 @@ struct calc_grad_with_block_op<T, base_device::DEVICE_GPU> {
347347
// Partially specialize functor for base_device::GpuDevice.
348348
template <typename T> struct dot_real_op<T, base_device::DEVICE_GPU> {
349349
using Real = typename GetTypeReal<T>::type;
350-
Real operator()(const base_device::DEVICE_GPU *d, const int &dim,
350+
Real operator()(const int &dim,
351351
const T *psi_L, const T *psi_R, const bool reduce = true);
352352
};
353353

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,7 @@ void calc_grad_with_block_op<T, base_device::DEVICE_GPU>::operator()(const Real*
374374
}
375375

376376
template <>
377-
double dot_real_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
378-
const int& dim,
377+
double dot_real_op<double, base_device::DEVICE_GPU>::operator()(const int& dim,
379378
const double* psi_L,
380379
const double* psi_R,
381380
const bool reduce)
@@ -411,17 +410,15 @@ inline FPTYPE dot_complex_wrapper(const base_device::DEVICE_GPU* d,
411410
return result;
412411
}
413412
template <>
414-
float dot_real_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
415-
const int& dim,
413+
float dot_real_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int& dim,
416414
const std::complex<float>* psi_L,
417415
const std::complex<float>* psi_R,
418416
const bool reduce)
419417
{
420418
return dot_complex_wrapper(d, dim, psi_L, psi_R, reduce);
421419
}
422420
template <>
423-
double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
424-
const int& dim,
421+
double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int& dim,
425422
const std::complex<double>* psi_L,
426423
const std::complex<double>* psi_R,
427424
const bool reduce)

source/module_base/kernels/test/math_kernel_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class TestModuleHsolverMathKernel : public ::testing::Test
257257
// base_device::AbacusDevice_t device = base_device::CpuDevice, const bool reduce = true);
258258
TEST_F(TestModuleHsolverMathKernel, zdot_real_op_cpu)
259259
{
260-
double result = zdot_real_cpu_op()(cpu_ctx, dim, psi_L.data(), psi_R.data(), false);
260+
double result = zdot_real_cpu_op()(dim, psi_L.data(), psi_R.data(), false);
261261
EXPECT_LT(fabs(result - expected_result), 1e-12);
262262
}
263263

@@ -376,7 +376,7 @@ TEST_F(TestModuleHsolverMathKernel, zdot_real_op_gpu)
376376
synchronize_memory_op()(psi_L_dev, psi_L.data(), psi_L.size());
377377
synchronize_memory_op()(psi_R_dev, psi_R.data(), psi_R.size());
378378
ModuleBase::createGpuBlasHandle();
379-
double result = zdot_real_gpu_op()(gpu_ctx, dim, psi_L_dev, psi_R_dev, false);
379+
double result = zdot_real_gpu_op()(dim, psi_L_dev, psi_R_dev, false);
380380
ModuleBase::destoryBLAShandle();
381381
EXPECT_LT(fabs(result - expected_result), 1e-12);
382382
delete_memory_op()(psi_L_dev);

source/module_hsolver/diago_cg.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in,
127127
this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)>
128128
this->hpsi_func_(phi_m, hphi); // hphi = H|psi(m)>
129129

130-
eigen_pack[m] = dot_real_op()(ctx_, this->n_basis_, phi_m.data<T>(), hphi.data<T>());
130+
eigen_pack[m] = dot_real_op()(this->n_basis_, phi_m.data<T>(), hphi.data<T>());
131131

132132
int iter = 0;
133133
Real gg_last = 0.0;
@@ -231,9 +231,9 @@ void DiagoCG<T, Device>::calc_grad(const ct::Tensor& prec,
231231

232232
// Update lambda !
233233
// (4) <psi|SPH|psi >
234-
const Real eh = ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, sphi.data<T>(), grad.data<T>());
234+
const Real eh = ModuleBase::dot_real_op<T, Device>()(this->n_basis_, sphi.data<T>(), grad.data<T>());
235235
// (5) <psi|SPS|psi >
236-
const Real es = ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, sphi.data<T>(), pphi.data<T>());
236+
const Real es = ModuleBase::dot_real_op<T, Device>()(this->n_basis_, sphi.data<T>(), pphi.data<T>());
237237
const Real lambda = eh / es;
238238

239239
// Update g!
@@ -328,7 +328,7 @@ void DiagoCG<T, Device>::calc_gamma_cg(const int& iter,
328328
// gg_inter = <g|g0>
329329
// Attention : the 'g' in g0 is getted last time
330330
gg_inter
331-
= ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, grad.data<T>(), g0.data<T>()); // b means before
331+
= ModuleBase::dot_real_op<T, Device>()(this->n_basis_, grad.data<T>(), g0.data<T>()); // b means before
332332
}
333333

334334
// (2) Update for g0!
@@ -346,7 +346,7 @@ void DiagoCG<T, Device>::calc_gamma_cg(const int& iter,
346346

347347
// (3) Update gg_now!
348348
// gg_now = < g|P|scg > = < g|g0 >
349-
const Real gg_now = ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, grad.data<T>(), g0.data<T>());
349+
const Real gg_now = ModuleBase::dot_real_op<T, Device>()(this->n_basis_, grad.data<T>(), g0.data<T>());
350350

351351
if (iter == 0)
352352
{
@@ -404,15 +404,15 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
404404
ct::Tensor& sphi,
405405
ct::Tensor& hphi)
406406
{
407-
cg_norm = sqrt(ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, cg.data<T>(), scg.data<T>()));
407+
cg_norm = sqrt(ModuleBase::dot_real_op<T, Device>()(this->n_basis_, cg.data<T>(), scg.data<T>()));
408408

409409
if (cg_norm < 1.0e-10)
410410
return true;
411411

412412
const Real a0
413-
= ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, phi_m.data<T>(), pphi.data<T>()) * 2.0 / cg_norm;
413+
= ModuleBase::dot_real_op<T, Device>()(this->n_basis_, phi_m.data<T>(), pphi.data<T>()) * 2.0 / cg_norm;
414414
const Real b0
415-
= ModuleBase::dot_real_op<T, Device>()(ctx_, this->n_basis_, cg.data<T>(), pphi.data<T>()) / (cg_norm * cg_norm);
415+
= ModuleBase::dot_real_op<T, Device>()(this->n_basis_, cg.data<T>(), pphi.data<T>()) / (cg_norm * cg_norm);
416416

417417
const Real e0 = eigen;
418418
theta = atan(a0 / (e0 - b0)) / 2.0;
@@ -538,7 +538,7 @@ void DiagoCG<T, Device>::schmit_orth(const int& m, const ct::Tensor& psi, const
538538
}*/
539539
//>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
540540
auto psi_norm = ct::extract<Real>(lagrange_so[m])
541-
- dot_real_op()(ctx_, m, lagrange_so.data<T>(), lagrange_so.data<T>(), false);
541+
- dot_real_op()(m, lagrange_so.data<T>(), lagrange_so.data<T>(), false);
542542

543543
if (psi_norm <= 0.0)
544544
{

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
375375
std::vector<Real> psi_norm(notconv, 0.0);
376376
for (size_t i = 0; i < notconv; i++)
377377
{
378-
psi_norm[i] = ModuleBase::dot_real_op<T, Device>()(this->ctx,
379-
this->dim,
378+
psi_norm[i] = ModuleBase::dot_real_op<T, Device>()(this->dim,
380379
psi_iter + (nbase + i) * this->dim,
381380
psi_iter + (nbase + i) * this->dim,
382381
true);

source/module_hsolver/diago_david.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ void DiagoDavid<T, Device>::SchmidtOrth(const int& dim,
953953
1);
954954

955955
// psi_norm = psi_norm - lagrange_m · lagrange_m
956-
psi_norm -= ModuleBase::dot_real_op<T, Device>()(this->ctx, m, lagrange_m, lagrange_m, false);
956+
psi_norm -= ModuleBase::dot_real_op<T, Device>()(m, lagrange_m, lagrange_m, false);
957957

958958
// for (int j = 0; j < m; j++)
959959
// {

source/module_hsolver/kernels/test/perf_math_kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture {
162162

163163
BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_zdot_real_cpu_op)(benchmark::State& state) {
164164
for (auto _ : state) {
165-
double result = zdot_real_cpu_op()(cpu_ctx, dim_vector, test_zvector_a, test_zvector_b, false);
165+
double result = zdot_real_cpu_op()(dim_vector, test_zvector_a, test_zvector_b, false);
166166
}
167167
}
168168

@@ -232,7 +232,7 @@ BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_zdot_real_gpu_op)(benchmark::
232232

233233
BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_zdot_real_gpu_op)(benchmark::State& state) {
234234
for (auto _ : state) {
235-
double result = zdot_real_gpu_op()(gpu_ctx, dim_vector, test_zvector_a_gpu, test_zvector_b_gpu, false);
235+
double result = zdot_real_gpu_op()(dim_vector, test_zvector_a_gpu, test_zvector_b_gpu, false);
236236
}
237237
}
238238

0 commit comments

Comments
 (0)