@@ -92,12 +92,16 @@ void ModuleIO::ctrl_scf_pw(const int istep,
9292 const ModulePW::PW_Basis_Big *pw_big,
9393 Setup_Psi<T, Device> &stp,
9494 const Device* ctx,
95+ const base_device::AbacusDevice_t &device,
9596 const Parallel_Grid ¶_grid,
9697 const Input_para& inp)
9798{
9899 ModuleBase::TITLE (" ModuleIO" , " ctrl_scf_pw" );
99100 ModuleBase::timer::tick (" ModuleIO" , " ctrl_scf_pw" );
100101
102+ // Transfer data from device (GPU) to host (CPU) in pw basis
103+ stp.copy_d2h (device);
104+
101105 // ----------------------------------------------------------
102106 // ! 4) Compute density of states (DOS)
103107 // ----------------------------------------------------------
@@ -382,6 +386,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_CPU
382386 const ModulePW::PW_Basis_Big *pw_big,
383387 Setup_Psi<std::complex <float >, base_device::DEVICE_CPU> &stp,
384388 const base_device::DEVICE_CPU* ctx,
389+ const base_device::AbacusDevice_t &device,
385390 const Parallel_Grid ¶_grid,
386391 const Input_para& inp);
387392
@@ -398,6 +403,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_CP
398403 const ModulePW::PW_Basis_Big *pw_big,
399404 Setup_Psi<std::complex <double >, base_device::DEVICE_CPU> &stp,
400405 const base_device::DEVICE_CPU* ctx,
406+ const base_device::AbacusDevice_t &device,
401407 const Parallel_Grid ¶_grid,
402408 const Input_para& inp);
403409
@@ -415,6 +421,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_GPU
415421 const ModulePW::PW_Basis_Big *pw_big,
416422 Setup_Psi<std::complex <float >, base_device::DEVICE_GPU> &stp,
417423 const base_device::DEVICE_GPU* ctx,
424+ const base_device::AbacusDevice_t &device,
418425 const Parallel_Grid ¶_grid,
419426 const Input_para& inp);
420427
@@ -431,6 +438,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_GP
431438 const ModulePW::PW_Basis_Big *pw_big,
432439 Setup_Psi<std::complex <double >, base_device::DEVICE_GPU> &stp,
433440 const base_device::DEVICE_GPU* ctx,
441+ const base_device::AbacusDevice_t &device,
434442 const Parallel_Grid ¶_grid,
435443 const Input_para& inp);
436444#endif
0 commit comments