Skip to content

Commit e6492cb

Browse files
committed
change function name
1 parent 151d87e commit e6492cb

File tree

6 files changed

+20
-7
lines changed

6 files changed

+20
-7
lines changed

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,10 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
392392
// Call 'after_scf' of ESolver_KS
393393
ESolver_KS<T, Device>::after_scf(ucell, istep, conv_esolver);
394394

395-
// Transfer data from GPU to CPU in pw basis
396-
this->stp.copy_g2c(this->device);
397-
398395
// Output quantities
399396
ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc,
400397
this->pw_rho, this->pw_rhod, this->pw_big, this->stp,
401-
this->ctx, this->Pgrid, PARAM.inp);
398+
this->ctx, this->device, this->Pgrid, PARAM.inp);
402399

403400
ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
404401
}

source/source_esolver/esolver_ks_pw.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
6060
// DFT-1/2 method
6161
VSep* vsep_cell = nullptr;
6262

63+
// ctx is passed as input to the FFT in get_pchg and get_wf
6364
Device* ctx = {};
6465

66+
// for device-to-host data transfer
6567
base_device::AbacusDevice_t device = {};
6668

6769
};

source/source_io/ctrl_output_pw.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,16 @@ void ModuleIO::ctrl_scf_pw(const int istep,
9292
const ModulePW::PW_Basis_Big *pw_big,
9393
Setup_Psi<T, Device> &stp,
9494
const Device* ctx,
95+
const base_device::AbacusDevice_t &device,
9596
const Parallel_Grid &para_grid,
9697
const Input_para& inp)
9798
{
9899
ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw");
99100
ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw");
100101

102+
// Transfer data from device (GPU) to host (CPU) in pw basis
103+
stp.copy_d2h(device);
104+
101105
//----------------------------------------------------------
102106
//! 4) Compute density of states (DOS)
103107
//----------------------------------------------------------
@@ -382,6 +386,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_CPU
382386
const ModulePW::PW_Basis_Big *pw_big,
383387
Setup_Psi<std::complex<float>, base_device::DEVICE_CPU> &stp,
384388
const base_device::DEVICE_CPU* ctx,
389+
const base_device::AbacusDevice_t &device,
385390
const Parallel_Grid &para_grid,
386391
const Input_para& inp);
387392

@@ -398,6 +403,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_CP
398403
const ModulePW::PW_Basis_Big *pw_big,
399404
Setup_Psi<std::complex<double>, base_device::DEVICE_CPU> &stp,
400405
const base_device::DEVICE_CPU* ctx,
406+
const base_device::AbacusDevice_t &device,
401407
const Parallel_Grid &para_grid,
402408
const Input_para& inp);
403409

@@ -415,6 +421,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_GPU
415421
const ModulePW::PW_Basis_Big *pw_big,
416422
Setup_Psi<std::complex<float>, base_device::DEVICE_GPU> &stp,
417423
const base_device::DEVICE_GPU* ctx,
424+
const base_device::AbacusDevice_t &device,
418425
const Parallel_Grid &para_grid,
419426
const Input_para& inp);
420427

@@ -431,6 +438,7 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_GP
431438
const ModulePW::PW_Basis_Big *pw_big,
432439
Setup_Psi<std::complex<double>, base_device::DEVICE_GPU> &stp,
433440
const base_device::DEVICE_GPU* ctx,
441+
const base_device::AbacusDevice_t &device,
434442
const Parallel_Grid &para_grid,
435443
const Input_para& inp);
436444
#endif

source/source_io/ctrl_output_pw.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ void ctrl_scf_pw(const int istep,
3131
const ModulePW::PW_Basis_Big *pw_big,
3232
Setup_Psi<T, Device> &stp,
3333
const Device* ctx,
34+
const base_device::AbacusDevice_t &device, // mohan add 2025-10-15
3435
const Parallel_Grid &para_grid,
3536
const Input_para& inp);
3637

source/source_psi/setup_psi.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,19 @@ void Setup_Psi<T, Device>::init(hamilt::Hamilt<T, Device>* p_hamilt)
6464

6565
// Transfer data from device (GPU) to host (CPU) in pw basis
6666
template <typename T, typename Device>
67-
void Setup_Psi<T, Device>::copy_g2c(base_device::AbacusDevice_t &device)
67+
void Setup_Psi<T, Device>::copy_d2h(const base_device::AbacusDevice_t &device)
6868
{
6969
if (device == base_device::GpuDevice)
7070
{
7171
castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(),
7272
this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(),
7373
this->psi_cpu[0].size());
7474
}
75+
else
76+
{
77+
// do nothing
78+
}
79+
return;
7580
}
7681

7782

source/source_psi/setup_psi.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ class Setup_Psi
5858

5959
void update_psi_d();
6060

61-
// Transfer data from GPU to CPU in pw basis
62-
void copy_g2c(base_device::AbacusDevice_t &device);
61+
// Transfer data from device to host in pw basis
62+
void copy_d2h(const base_device::AbacusDevice_t &device);
6363

6464
void clean();
6565

0 commit comments

Comments
 (0)