Skip to content

Commit 6e1dd33

Browse files
committed
Move hardware initializer out from esolver
1 parent 9e6bb70 commit 6e1dd33

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,18 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
5858
this->device = base_device::get_device_type<Device>(this->ctx);
5959

6060
#if ((defined __CUDA) || (defined __ROCM))
61-
if (this->device == base_device::GpuDevice)
61+
/*if (this->device == base_device::GpuDevice)
6262
{
6363
ModuleBase::createGpuBlasHandle();
6464
hsolver::createGpuSolverHandle();
6565
container::kernels::createGpuBlasHandle();
6666
container::kernels::createGpuSolverHandle();
67-
}
67+
}*/
6868
#endif
6969

7070
#ifdef __DSP
71-
std::cout << " ** Initializing DSP Hardware..." << std::endl;
72-
mtfunc::dspInitHandle(GlobalV::MY_RANK);
71+
// std::cout << " ** Initializing DSP Hardware..." << std::endl;
72+
// mtfunc::dspInitHandle(GlobalV::MY_RANK);
7373
#endif
7474
}
7575

source/source_main/driver.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class Driver
3737

3838
// the actual calculations
3939
void driver_run();
40+
41+
// Init harewares according to Input parameters
42+
void init_hardware();
4043
};
4144

4245
#endif

source/source_main/driver_run.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66
#include "source_io/para_json.h"
77
#include "source_io/print_info.h"
88
#include "source_md/run_md.h"
9+
#include "source_base/module_device/device.h"
10+
#include "source_base/module_device/memory_op.h"
11+
#include "source_base/kernels/math_kernel_op.h"
12+
#include "source_hsolver/kernels/dngvd_op.h"
13+
14+
#include <ATen/kernels/blas.h>
15+
#include <ATen/kernels/lapack.h>
16+
17+
#ifdef __DSP
18+
#include "source_base/kernels/dsp/dsp_connector.h"
19+
#endif
920

1021
/**
1122
* @brief This is the driver function which defines the workflow of ABACUS
@@ -47,6 +58,8 @@ void Driver::driver_run()
4758
unitcell::check_atomic_stru(ucell, PARAM.inp.min_dist_coef);
4859

4960
//! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`)
61+
this->init_hardware();
62+
5063
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, ucell);
5164

5265
//! 3: initialize Esolver and fill json-structure
@@ -99,3 +112,20 @@ void Driver::driver_run()
99112

100113
return;
101114
}
115+
116+
void Driver::init_hardware(){
117+
#if ((defined __CUDA) || (defined __ROCM))
118+
if (PARAM.inp.device == "gpu")
119+
{
120+
ModuleBase::createGpuBlasHandle();
121+
hsolver::createGpuSolverHandle();
122+
container::kernels::createGpuBlasHandle();
123+
container::kernels::createGpuSolverHandle();
124+
}
125+
#endif
126+
127+
#ifdef __DSP
128+
std::cout << " ** Initializing DSP Hardware..." << std::endl;
129+
mtfunc::dspInitHandle(GlobalV::MY_RANK);
130+
#endif
131+
}

0 commit comments

Comments
 (0)