Skip to content

Commit fbd4b5b

Browse files
Critsium-xyWuming-HUST
authored andcommitted
[Refactor] Move hardware initializer out from esolver code (deepmodeling#6494)
* Move hardware initializer out from esolver * Remove useless codes * Remove finalize code out
1 parent b864781 commit fbd4b5b

File tree

3 files changed

+54
-30
lines changed

3 files changed

+54
-30
lines changed

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,6 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
5757
this->classname = "ESolver_KS_PW";
5858
this->basisname = "PW";
5959
this->device = base_device::get_device_type<Device>(this->ctx);
60-
61-
#if ((defined __CUDA) || (defined __ROCM))
62-
if (this->device == base_device::GpuDevice)
63-
{
64-
ModuleBase::createGpuBlasHandle();
65-
hsolver::createGpuSolverHandle();
66-
container::kernels::createGpuBlasHandle();
67-
container::kernels::createGpuSolverHandle();
68-
}
69-
#endif
70-
71-
#ifdef __DSP
72-
std::cout << " ** Initializing DSP Hardware..." << std::endl;
73-
mtfunc::dspInitHandle(GlobalV::MY_RANK);
74-
#endif
7560
}
7661

7762
template <typename T, typename Device>
@@ -87,21 +72,6 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
8772
this->pelec = nullptr;
8873
}
8974

90-
if (this->device == base_device::GpuDevice)
91-
{
92-
#if defined(__CUDA) || defined(__ROCM)
93-
ModuleBase::destoryBLAShandle();
94-
hsolver::destroyGpuSolverHandle();
95-
container::kernels::destroyGpuBlasHandle();
96-
container::kernels::destroyGpuSolverHandle();
97-
#endif
98-
}
99-
100-
#ifdef __DSP
101-
std::cout << " ** Closing DSP Hardware..." << std::endl;
102-
mtfunc::dspDestoryHandle(GlobalV::MY_RANK);
103-
#endif
104-
10575
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
10676
{
10777
delete this->kspw_psi;

source/source_main/driver.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class Driver
3737

3838
// the actual calculations
3939
void driver_run();
40+
41+
// Init harewares according to Input parameters
42+
void init_hardware();
43+
void finalize_hardware();
4044
};
4145

4246
#endif

source/source_main/driver_run.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66
#include "source_io/para_json.h"
77
#include "source_io/print_info.h"
88
#include "source_md/run_md.h"
9+
#include "source_base/module_device/device.h"
10+
#include "source_base/module_device/memory_op.h"
11+
#include "source_base/kernels/math_kernel_op.h"
12+
#include "source_hsolver/kernels/dngvd_op.h"
13+
14+
#include <ATen/kernels/blas.h>
15+
#include <ATen/kernels/lapack.h>
16+
17+
#ifdef __DSP
18+
#include "source_base/kernels/dsp/dsp_connector.h"
19+
#endif
920

1021
/**
1122
* @brief This is the driver function which defines the workflow of ABACUS
@@ -47,6 +58,8 @@ void Driver::driver_run()
4758
unitcell::check_atomic_stru(ucell, PARAM.inp.min_dist_coef);
4859

4960
//! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`)
61+
this->init_hardware();
62+
5063
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, ucell);
5164

5265
//! 3: initialize Esolver and fill json-structure
@@ -93,9 +106,46 @@ void Driver::driver_run()
93106
p_esolver->after_all_runners(ucell);
94107

95108
ModuleESolver::clean_esolver(p_esolver);
109+
this->finalize_hardware();
96110

97111
//! 6: output the json file
98112
Json::create_Json(&ucell, PARAM);
99113

100114
return;
101115
}
116+
117+
void Driver::init_hardware()
118+
{
119+
#if ((defined __CUDA) || (defined __ROCM))
120+
if (PARAM.inp.device == "gpu")
121+
{
122+
ModuleBase::createGpuBlasHandle();
123+
hsolver::createGpuSolverHandle();
124+
container::kernels::createGpuBlasHandle();
125+
container::kernels::createGpuSolverHandle();
126+
}
127+
#endif
128+
129+
#ifdef __DSP
130+
std::cout << " ** Initializing DSP Hardware..." << std::endl;
131+
mtfunc::dspInitHandle(GlobalV::MY_RANK);
132+
#endif
133+
}
134+
135+
void Driver::finalize_hardware()
136+
{
137+
#if defined(__CUDA) || defined(__ROCM)
138+
if (PARAM.inp.device == "gpu")
139+
{
140+
ModuleBase::destoryBLAShandle();
141+
hsolver::destroyGpuSolverHandle();
142+
container::kernels::destroyGpuBlasHandle();
143+
container::kernels::destroyGpuSolverHandle();
144+
}
145+
#endif
146+
147+
#ifdef __DSP
148+
std::cout << " ** Closing DSP Hardware..." << std::endl;
149+
mtfunc::dspDestoryHandle(GlobalV::MY_RANK);
150+
#endif
151+
}

0 commit comments

Comments
 (0)