Skip to content

Commit f678433

Browse files
committed
add the cuda support for the ft
1 parent cdc3e96 commit f678433

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

source/module_basis/module_pw/pw_transform_k.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
347347
base_device::memory::synchronize_memory_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(
348348
ctx,
349349
ctx,
350-
this->ft.get_auxr_3d_data<float>(),
350+
this->ft1.get_auxr_3d_data<float>(),
351351
in,
352352
this->nrxx);
353353

354-
this->ft.fft3D_forward(ctx, this->ft.get_auxr_3d_data<float>(), this->ft.get_auxr_3d_data<float>());
354+
this->ft1.fft3D_forward(ctx, this->ft1.get_auxr_3d_data<float>(), this->ft1.get_auxr_3d_data<float>());
355355

356356
const int startig = ik * this->npwk_max;
357357
const int npw_k = this->npwk[ik];
@@ -361,7 +361,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
361361
add,
362362
factor,
363363
this->ig2ixyz_k + startig,
364-
this->ft.get_auxr_3d_data<float>(),
364+
this->ft1.get_auxr_3d_data<float>(),
365365
out);
366366
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
367367
}
@@ -381,11 +381,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
381381
base_device::DEVICE_GPU,
382382
base_device::DEVICE_GPU>()(ctx,
383383
ctx,
384-
this->ft.get_auxr_3d_data<double>(),
384+
this->ft1.get_auxr_3d_data<double>(),
385385
in,
386386
this->nrxx);
387387

388-
this->ft.fft3D_forward(ctx, this->ft.get_auxr_3d_data<double>(), this->ft.get_auxr_3d_data<double>());
388+
this->ft1.fft3D_forward(ctx, this->ft1.get_auxr_3d_data<double>(), this->ft1.get_auxr_3d_data<double>());
389389

390390
const int startig = ik * this->npwk_max;
391391
const int npw_k = this->npwk[ik];
@@ -395,7 +395,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
395395
add,
396396
factor,
397397
this->ig2ixyz_k + startig,
398-
this->ft.get_auxr_3d_data<double>(),
398+
this->ft1.get_auxr_3d_data<double>(),
399399
out);
400400
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
401401
}
@@ -411,10 +411,10 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
411411
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
412412
assert(this->gamma_only == false);
413413
assert(this->poolnproc == 1);
414-
// ModuleBase::GlobalFunc::ZEROS(ft.get_auxr_3d_data<float>(), this->nxyz);
414+
// ModuleBase::GlobalFunc::ZEROS(ft1.get_auxr_3d_data<float>(), this->nxyz);
415415
base_device::memory::set_memory_op<std::complex<float>, base_device::DEVICE_GPU>()(
416416
ctx,
417-
this->ft.get_auxr_3d_data<float>(),
417+
this->ft1.get_auxr_3d_data<float>(),
418418
0,
419419
this->nxyz);
420420

@@ -425,14 +425,14 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
425425
npw_k,
426426
this->ig2ixyz_k + startig,
427427
in,
428-
this->ft.get_auxr_3d_data<float>());
429-
this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data<float>(), this->ft.get_auxr_3d_data<float>());
428+
this->ft1.get_auxr_3d_data<float>());
429+
this->ft1.fft3D_backward(ctx, this->ft1.get_auxr_3d_data<float>(), this->ft1.get_auxr_3d_data<float>());
430430

431431
set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx,
432432
this->nrxx,
433433
add,
434434
factor,
435-
this->ft.get_auxr_3d_data<float>(),
435+
this->ft1.get_auxr_3d_data<float>(),
436436
out);
437437

438438
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
@@ -448,10 +448,10 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
448448
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
449449
assert(this->gamma_only == false);
450450
assert(this->poolnproc == 1);
451-
// ModuleBase::GlobalFunc::ZEROS(ft.get_auxr_3d_data<double>(), this->nxyz);
451+
// ModuleBase::GlobalFunc::ZEROS(ft1.get_auxr_3d_data<double>(), this->nxyz);
452452
base_device::memory::set_memory_op<std::complex<double>, base_device::DEVICE_GPU>()(
453453
ctx,
454-
this->ft.get_auxr_3d_data<double>(),
454+
this->ft1.get_auxr_3d_data<double>(),
455455
0,
456456
this->nxyz);
457457

@@ -462,14 +462,14 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
462462
npw_k,
463463
this->ig2ixyz_k + startig,
464464
in,
465-
this->ft.get_auxr_3d_data<double>());
466-
this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data<double>(), this->ft.get_auxr_3d_data<double>());
465+
this->ft1.get_auxr_3d_data<double>());
466+
this->ft1.fft3D_backward(ctx, this->ft1.get_auxr_3d_data<double>(), this->ft1.get_auxr_3d_data<double>());
467467

468468
set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx,
469469
this->nrxx,
470470
add,
471471
factor,
472-
this->ft.get_auxr_3d_data<double>(),
472+
this->ft1.get_auxr_3d_data<double>(),
473473
out);
474474

475475
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");

0 commit comments

Comments
 (0)