Skip to content

Commit a105a3c

Browse files
committed
remove ctx in bundle
1 parent b96bb77 commit a105a3c

File tree

6 files changed

+26
-35
lines changed

6 files changed

+26
-35
lines changed

source/module_basis/module_pw/kernels/test/pw_op_test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class TestModulePWPWMultiDevice : public ::testing::Test
7272
TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu)
7373
{
7474
std::vector<std::complex<double>> res(out_1.size(), std::complex<double>{0, 0});
75-
set_3d_fft_box_cpu_op()(cpu_ctx, this->npwk, box_index.data(), in_1.data(), res.data());
75+
set_3d_fft_box_cpu_op()(this->npwk, box_index.data(), in_1.data(), res.data());
7676
for (int ii = 0; ii < this->nxyz; ii++) {
7777
EXPECT_LT(std::abs(res[ii] - out_1[ii]), 1e-12);
7878
}
@@ -81,7 +81,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu)
8181
TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu)
8282
{
8383
std::vector<std::complex<double>> res(out_2.size(), std::complex<double>{0, 0});
84-
set_recip_to_real_output_cpu_op()(cpu_ctx, this->nxyz, this->add, this->factor, in_2.data(), res.data());
84+
set_recip_to_real_output_cpu_op()(this->nxyz, this->add, this->factor, in_2.data(), res.data());
8585
for (int ii = 0; ii < this->nxyz; ii++) {
8686
EXPECT_LT(std::abs(res[ii] - out_2[ii]), 1e-12);
8787
}
@@ -90,7 +90,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu)
9090
TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_cpu)
9191
{
9292
std::vector<std::complex<double>> res = out_3_init;
93-
set_real_to_recip_output_cpu_op()(cpu_ctx, this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data());
93+
set_real_to_recip_output_cpu_op()(this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data());
9494
for (int ii = 0; ii < out_3.size(); ii++) {
9595
EXPECT_LT(std::abs(res[ii] - out_3[ii]), 5e-6);
9696
}
@@ -109,7 +109,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_gpu)
109109
synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
110110
synchronize_memory_complex_h2d_op()(d_in_1, in_1.data(), in_1.size());
111111

112-
set_3d_fft_box_gpu_op()(gpu_ctx, this->npwk, d_box_index, d_in_1, d_res);
112+
set_3d_fft_box_gpu_op()(this->npwk, d_box_index, d_in_1, d_res);
113113

114114
synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
115115

@@ -130,7 +130,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_gpu)
130130
synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
131131
synchronize_memory_complex_h2d_op()(d_in_2, in_2.data(), in_2.size());
132132

133-
set_recip_to_real_output_gpu_op()(gpu_ctx, this->nxyz, this->add, this->factor, d_in_2, d_res);
133+
set_recip_to_real_output_gpu_op()(this->nxyz, this->add, this->factor, d_in_2, d_res);
134134

135135
synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
136136

@@ -153,7 +153,7 @@ TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_gpu)
153153
synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
154154
synchronize_memory_complex_h2d_op()(d_in_3, in_3.data(), in_3.size());
155155

156-
set_real_to_recip_output_gpu_op()(gpu_ctx, this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res);
156+
set_real_to_recip_output_gpu_op()(this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res);
157157

158158
synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
159159

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,30 +227,26 @@ void FFT_Bundle::fftxyc2r(std::complex<double>* in, double* out) const
227227
}
228228

229229
template <>
230-
void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
231-
std::complex<float>* in,
230+
void FFT_Bundle::fft3D_forward(std::complex<float>* in,
232231
std::complex<float>* out) const
233232
{
234233
fft_float->fft3D_forward(in, out);
235234
}
236235
template <>
237-
void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
238-
std::complex<double>* in,
236+
void FFT_Bundle::fft3D_forward(std::complex<double>* in,
239237
std::complex<double>* out) const
240238
{
241239
fft_double->fft3D_forward(in, out);
242240
}
243241

244242
template <>
245-
void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
246-
std::complex<float>* in,
243+
void FFT_Bundle::fft3D_backward(std::complex<float>* in,
247244
std::complex<float>* out) const
248245
{
249246
fft_float->fft3D_backward(in, out);
250247
}
251248
template <>
252-
void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
253-
std::complex<double>* in,
249+
void FFT_Bundle::fft3D_backward(std::complex<double>* in,
254250
std::complex<double>* out) const
255251
{
256252
fft_double->fft3D_backward(in, out);

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,10 @@ class FFT_Bundle
188188
template <typename FPTYPE>
189189
void fftxyc2r(std::complex<FPTYPE>* in, FPTYPE* out) const;
190190

191-
template <typename FPTYPE, typename Device>
192-
void fft3D_forward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
193-
template <typename FPTYPE, typename Device>
194-
void fft3D_backward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
191+
template <typename FPTYPE>
192+
void fft3D_forward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
193+
template <typename FPTYPE>
194+
void fft3D_backward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
195195

196196
private:
197197
int fft_mode = 0;

source/module_basis/module_pw/pw_transform_gpu.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ void PW_Basis::real2recip_gpu(const FPTYPE* in,
2020
// in,
2121
// this->nrxx);
2222

23-
this->fft_bundle.fft3D_forward(ctx,
24-
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
23+
this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
2524
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
2625

2726
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
@@ -49,8 +48,7 @@ void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
4948
in,
5049
this->nrxx);
5150

52-
this->fft_bundle.fft3D_forward(ctx,
53-
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
51+
this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
5452
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
5553

5654
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
@@ -83,8 +81,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
8381
this->ig2isz,
8482
in,
8583
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
86-
this->fft_bundle.fft3D_backward(ctx,
87-
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
84+
this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
8885
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
8986

9087
set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
@@ -115,8 +112,7 @@ template <typename FPTYPE>
115112
this->ig2isz,
116113
in,
117114
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
118-
this->fft_bundle.fft3D_backward(ctx,
119-
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
115+
this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
120116
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
121117

122118
set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,

source/module_basis/module_pw/pw_transform_k.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
357357
in,
358358
this->nrxx);
359359

360-
this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
360+
this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
361361

362362
const int startig = ik * this->npwk_max;
363363
const int npw_k = this->npwk[ik];
@@ -388,7 +388,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
388388
in,
389389
this->nrxx);
390390

391-
this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
391+
this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
392392

393393
const int startig = ik * this->npwk_max;
394394
const int npw_k = this->npwk[ik];
@@ -426,7 +426,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
426426
this->ig2ixyz_k + startig,
427427
in,
428428
this->fft_bundle.get_auxr_3d_data<float>());
429-
this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
429+
this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
430430

431431
set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(this->nrxx,
432432
add,
@@ -460,7 +460,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
460460
this->ig2ixyz_k + startig,
461461
in,
462462
this->fft_bundle.get_auxr_3d_data<double>());
463-
this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
463+
this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
464464

465465
set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(this->nrxx,
466466
add,

source/module_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
2727

2828
// 3d fft
2929
this->fft_bundle.resource_handler(1);
30-
this->fft_bundle.fft3D_forward(gpux,
31-
auxr,
30+
this->fft_bundle.fft3D_forward(auxr,
3231
auxr);
3332
this->fft_bundle.resource_handler(0);
3433
// copy the result from the auxr to the out ,while consider the add
@@ -60,7 +59,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
6059
set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
6160
// use 3d fft backward
6261
this->fft_bundle.resource_handler(1);
63-
this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
62+
this->fft_bundle.fft3D_backward(auxr, auxr);
6463
this->fft_bundle.resource_handler(0);
6564
if (add)
6665
{
@@ -109,15 +108,15 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
109108
set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr);
110109

111110
// use 3d fft backward
112-
this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
111+
this->fft_bundle.fft3D_backward(auxr, auxr);
113112

114113
for (int ir = 0; ir < size; ir++)
115114
{
116115
auxr[ir] *= input1[ir];
117116
}
118117

119118
// 3d fft
120-
this->fft_bundle.fft3D_forward(gpux, auxr, auxr);
119+
this->fft_bundle.fft3D_forward(auxr, auxr);
121120
// copy the result from the auxr to the out ,while consider the add
122121
set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k,
123122
this->nxyz,

0 commit comments

Comments
 (0)