diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 36435c4645..1f3c23da5a 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -56,21 +56,6 @@ ESolver_KS_PW::ESolver_KS_PW() this->classname = "ESolver_KS_PW"; this->basisname = "PW"; this->device = base_device::get_device_type(this->ctx); - -#if ((defined __CUDA) || (defined __ROCM)) - if (this->device == base_device::GpuDevice) - { - ModuleBase::createGpuBlasHandle(); - hsolver::createGpuSolverHandle(); - container::kernels::createGpuBlasHandle(); - container::kernels::createGpuSolverHandle(); - } -#endif - -#ifdef __DSP - std::cout << " ** Initializing DSP Hardware..." << std::endl; - mtfunc::dspInitHandle(GlobalV::MY_RANK); -#endif } template @@ -86,21 +71,6 @@ ESolver_KS_PW::~ESolver_KS_PW() this->pelec = nullptr; } - if (this->device == base_device::GpuDevice) - { -#if defined(__CUDA) || defined(__ROCM) - ModuleBase::destoryBLAShandle(); - hsolver::destroyGpuSolverHandle(); - container::kernels::destroyGpuBlasHandle(); - container::kernels::destroyGpuSolverHandle(); -#endif - } - -#ifdef __DSP - std::cout << " ** Closing DSP Hardware..." << std::endl; - mtfunc::dspDestoryHandle(GlobalV::MY_RANK); -#endif - if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { delete this->kspw_psi; diff --git a/source/source_main/driver.h b/source/source_main/driver.h index d0ef395875..9bdcf1d285 100644 --- a/source/source_main/driver.h +++ b/source/source_main/driver.h @@ -37,6 +37,10 @@ class Driver // the actual calculations void driver_run(); + + // Init harewares according to Input parameters + void init_hardware(); + void finalize_hardware(); }; #endif diff --git a/source/source_main/driver_run.cpp b/source/source_main/driver_run.cpp index 4743b95bf8..579b8fbd08 100644 --- a/source/source_main/driver_run.cpp +++ b/source/source_main/driver_run.cpp @@ -6,6 +6,17 @@ #include "source_io/para_json.h" #include "source_io/print_info.h" #include "source_md/run_md.h" +#include "source_base/module_device/device.h" +#include "source_base/module_device/memory_op.h" +#include "source_base/kernels/math_kernel_op.h" +#include "source_hsolver/kernels/dngvd_op.h" + +#include +#include + +#ifdef __DSP +#include "source_base/kernels/dsp/dsp_connector.h" +#endif /** * @brief This is the driver function which defines the workflow of ABACUS @@ -47,6 +58,8 @@ void Driver::driver_run() unitcell::check_atomic_stru(ucell, PARAM.inp.min_dist_coef); //! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`) + this->init_hardware(); + ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, ucell); //! 3: initialize Esolver and fill json-structure @@ -93,9 +106,46 @@ void Driver::driver_run() p_esolver->after_all_runners(ucell); ModuleESolver::clean_esolver(p_esolver); + this->finalize_hardware(); //! 6: output the json file Json::create_Json(&ucell, PARAM); return; } + +void Driver::init_hardware() +{ +#if ((defined __CUDA) || (defined __ROCM)) + if (PARAM.inp.device == "gpu") + { + ModuleBase::createGpuBlasHandle(); + hsolver::createGpuSolverHandle(); + container::kernels::createGpuBlasHandle(); + container::kernels::createGpuSolverHandle(); + } +#endif + +#ifdef __DSP + std::cout << " ** Initializing DSP Hardware..." << std::endl; + mtfunc::dspInitHandle(GlobalV::MY_RANK); +#endif +} + +void Driver::finalize_hardware() +{ +#if defined(__CUDA) || defined(__ROCM) + if (PARAM.inp.device == "gpu") + { + ModuleBase::destoryBLAShandle(); + hsolver::destroyGpuSolverHandle(); + container::kernels::destroyGpuBlasHandle(); + container::kernels::destroyGpuSolverHandle(); + } +#endif + +#ifdef __DSP + std::cout << " ** Closing DSP Hardware..." << std::endl; + mtfunc::dspDestoryHandle(GlobalV::MY_RANK); +#endif +} \ No newline at end of file