Skip to content

Commit e1f3d90

Browse files
author
Kai Luo
committed
take develop a lot of code
1 parent 73ea9dd commit e1f3d90

File tree

210 files changed

+12020
-2851
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

210 files changed

+12020
-2851
lines changed

source/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ list(APPEND device_srcs
3535
source_pw/module_pwdft/kernels/meta_op.cpp
3636
source_pw/module_stodft/kernels/hpsi_norm_op.cpp
3737
source_basis/module_pw/kernels/pw_op.cpp
38-
source_hsolver/kernels/dngvd_op.cpp
38+
source_hsolver/kernels/hegvd_op.cpp
3939
source_hsolver/kernels/bpcg_kernel_op.cpp
4040
source_estate/kernels/elecstate_op.cpp
4141

@@ -70,7 +70,7 @@ if(USE_CUDA)
7070
source_pw/module_stodft/kernels/cuda/hpsi_norm_op.cu
7171
source_pw/module_pwdft/kernels/cuda/onsite_op.cu
7272
source_basis/module_pw/kernels/cuda/pw_op.cu
73-
source_hsolver/kernels/cuda/dngvd_op.cu
73+
source_hsolver/kernels/cuda/hegvd_op.cu
7474
source_hsolver/kernels/cuda/bpcg_kernel_op.cu
7575
source_estate/kernels/cuda/elecstate_op.cu
7676

@@ -101,7 +101,7 @@ if(USE_ROCM)
101101
source_pw/module_pwdft/kernels/rocm/onsite_op.hip.cu
102102
source_pw/module_stodft/kernels/rocm/hpsi_norm_op.hip.cu
103103
source_basis/module_pw/kernels/rocm/pw_op.hip.cu
104-
source_hsolver/kernels/rocm/dngvd_op.hip.cu
104+
source_hsolver/kernels/rocm/hegvd_op.hip.cu
105105
source_hsolver/kernels/rocm/bpcg_kernel_op.hip.cu
106106
source_estate/kernels/rocm/elecstate_op.hip.cu
107107

source/source_base/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@ add_library(
1111
OBJECT
1212
assoc_laguerre.cpp
1313
module_external/blas_connector_base.cpp
14-
module_external/blas_connector_l1.cpp
15-
module_external/blas_connector_l2.cpp
16-
module_external/blas_connector_l3.cpp
17-
module_external/lapack_connector.cpp
14+
module_external/blas_connector_vector.cpp
15+
module_external/blas_connector_matrix.cpp
1816
clebsch_gordan_coeff.cpp
1917
complexarray.cpp
2018
complexmatrix.cpp

source/source_base/cubic_spline.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "cubic_spline.h"
2-
#include "source_base/module_external/lapack_connector.h"
32

43
#include <cassert>
54
#include <algorithm>
@@ -9,6 +8,13 @@
98

109
using ModuleBase::CubicSpline;
1110

11+
extern "C"
12+
{
13+
// solve a tridiagonal linear system
14+
void dgtsv_(int* N, int* NRHS, double* DL, double* D, double* DU, double* B, int* LDB, int* INFO);
15+
};
16+
17+
1218
CubicSpline::BoundaryCondition::BoundaryCondition(BoundaryType type)
1319
: type(type)
1420
{
@@ -471,7 +477,8 @@ void CubicSpline::_build(
471477

472478
int nrhs = 1;
473479
int ldb = n;
474-
LapackConnector::gtsv(LapackConnector::ColMajor, n, nrhs, l, d, u, dy, ldb);
480+
int info = 0;
481+
dgtsv_(&n, &nrhs, l, d, u, dy, &ldb, &info);
475482
}
476483
}
477484
}
@@ -545,8 +552,9 @@ void CubicSpline::_solve_cyctri(int n, double* d, double* u, double* l, double*
545552
d[n - 1] -= l[n - 1] * alpha / beta;
546553

547554
int nrhs = 2;
555+
int info = 0;
548556
int ldb = n;
549-
LapackConnector::gtsv(LapackConnector::ColMajor, n, nrhs, l, d, u, bp.data(), ldb);
557+
dgtsv_(&n, &nrhs, l, d, u, bp.data(), &ldb, &info);
550558

551559
double fac = (beta * u[n - 1] * bp[0] + alpha * l[n - 1] * bp[n - 1])
552560
/ (1. + beta * u[n - 1] * bp[n] + alpha * l[n - 1] * bp[2 * n - 1]);

source/source_base/gather_math_lib_info.cpp

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void zgemm_i(const char *transa,
3030
GlobalV::ofs_info.unsetf(std::ios_base::floatfield);
3131
GlobalV::ofs_info << "zgemm " << *transa << " " << *transb << " " << *m << " " << *n << " "
3232
<< *k << " " << *alpha << " " << *lda << " " << *ldb << " " << *beta << " " << *ldc << std::endl;
33-
BlasConnector::gemm_cm(*transa, *transb, *m, *n, *k, *alpha, a, *lda, b, *ldb, *beta, c, *ldc);
33+
zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
3434
}
3535

3636
void zaxpy_i(const int *N,
@@ -43,37 +43,38 @@ void zaxpy_i(const int *N,
4343
// std::cout << "zaxpy " << *N << std::endl;
4444
// alpha is a coefficient
4545
// incX, incY is always 1
46-
BlasConnector::axpy(*N, *alpha, X, *incX, Y, *incY);
46+
zaxpy_(N, alpha, X, incX, Y, incY);
4747
}
4848

49-
// void zhegvx_i(const int *itype,
50-
// const char *jobz,
51-
// const char *range,
52-
// const char *uplo,
53-
// const int *n,
54-
// std::complex<double> *a,
55-
// const int *lda,
56-
// std::complex<double> *b,
57-
// const int *ldb,
58-
// const double *vl,
59-
// const double *vu,
60-
// const int *il,
61-
// const int *iu,
62-
// const double *abstol,
63-
// int *m,
64-
// double *w,
65-
// std::complex<double> *z,
66-
// const int *ldz,
67-
// std::complex<double> *work,
68-
// const int *lwork,
69-
// double *rwork,
70-
// int *iwork,
71-
// int *ifail,
72-
// int *info)
73-
// {
74-
// GlobalV::ofs_info.unsetf(std::ios_base::floatfield);
75-
// GlobalV::ofs_info << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo
76-
// << " " << *n << " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu
77-
// << " " << *abstol << " " << *m << " " << *lwork << " " << *info << std::endl;
78-
// LapackConnector::hegvx(LapackConnector::ColMajor, *itype, *jobz, *range, *uplo, *n, a, *lda, b, *ldb, *vl, *vu, *il, *iu, *abstol, m, w, z, *ldz, ifail);
79-
// }
49+
void zhegvx_i(const int *itype,
50+
const char *jobz,
51+
const char *range,
52+
const char *uplo,
53+
const int *n,
54+
std::complex<double> *a,
55+
const int *lda,
56+
std::complex<double> *b,
57+
const int *ldb,
58+
const double *vl,
59+
const double *vu,
60+
const int *il,
61+
const int *iu,
62+
const double *abstol,
63+
const int *m,
64+
double *w,
65+
std::complex<double> *z,
66+
const int *ldz,
67+
std::complex<double> *work,
68+
const int *lwork,
69+
double *rwork,
70+
int *iwork,
71+
int *ifail,
72+
int *info)
73+
{
74+
GlobalV::ofs_info.unsetf(std::ios_base::floatfield);
75+
GlobalV::ofs_info << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo
76+
<< " " << *n << " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu
77+
<< " " << *abstol << " " << *m << " " << *lwork << " " << *info << std::endl;
78+
zhegvx_(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork,
79+
iwork, ifail, info);
80+
}

source/source_base/global_function.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,21 @@ inline void DCOPY(const T* a, T* b, const int& dim) {
182182
}
183183

184184
template <typename T>
185-
inline void COPYARRAY(const T* a, T* b, const long dim)
185+
inline void COPYARRAY(const T* a, T* b, const long dim);
186+
187+
template <>
188+
inline void COPYARRAY(const std::complex<double>* a, std::complex<double>* b, const long dim)
186189
{
187190
const int one = 1;
188-
BlasConnector::copy(dim, a, one, b, one);
191+
zcopy_(&dim, a, &one, b, &one);
189192
}
190193

194+
template <>
195+
inline void COPYARRAY(const double* a, double* b, const long dim)
196+
{
197+
const int one = 1;
198+
dcopy_(&dim, a, &one, b, &one);
199+
}
191200

192201
void BLOCK_HERE(const std::string& description);
193202

source/source_base/inverse_matrix.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Inverse_Matrix_Complex::~Inverse_Matrix_Complex()
1616
if(allocate)
1717
{
1818
delete[] e; //mohan fix bug 2012-04-02
19+
delete[] work2;
20+
delete[] rwork;
1921
allocate=false;
2022
}
2123
}
@@ -26,17 +28,23 @@ void Inverse_Matrix_Complex::init(const int &dim_in)
2628
if(allocate)
2729
{
2830
delete[] e; //mohan fix bug 2012-04-02
31+
delete[] work2;
32+
delete[] rwork;
2933
allocate=false;
3034
}
3135

3236
this->dim = dim_in;
3337

3438
assert(dim>0);
3539
this->e = new double[dim];
40+
this->lwork = 2*dim;
3641

3742
assert(lwork>0);
43+
this->work2 = new std::complex<double>[lwork];
3844

3945
assert(3*dim-2>0);
46+
this->rwork = new double[3*dim-2];
47+
this->info = 0;
4048
this->A.create(dim, dim);
4149
this->EA.create(dim, dim);
4250

@@ -51,7 +59,7 @@ void Inverse_Matrix_Complex::using_zheev( const ModuleBase::ComplexMatrix &Sin,
5159
ModuleBase::timer::tick("Inverse","using_zheev");
5260
this->A = Sin;
5361

54-
LapackConnector::heev(LapackConnector::RowMajor, 'V', 'U', dim, this->A.c, dim, e);
62+
LapackConnector::zheev('V', 'U', dim, this->A, dim, e, work2, lwork, rwork, &info);
5563

5664
for(int i=0; i<dim; i++)
5765
{
@@ -68,8 +76,11 @@ void Inverse_Matrix_Complex::using_zheev( const ModuleBase::ComplexMatrix &Sin,
6876

6977
void Inverse_Matrix_Real(const int dim, const double* in, double* out)
7078
{
79+
int info = 0;
7180
int lda = dim;
72-
std::vector<int> ipiv(dim);
81+
int lwork = 64 * dim;
82+
int* ipiv = new int[dim];
83+
double* work = new double[lwork];
7384

7485
for (int i = 0; i < dim; i++)
7586
{
@@ -79,7 +90,20 @@ void Inverse_Matrix_Real(const int dim, const double* in, double* out)
7990
}
8091
}
8192

82-
LapackConnector::getrf(LapackConnector::ColMajor, dim, dim, out, lda, ipiv.data());
83-
LapackConnector::getri(LapackConnector::ColMajor, dim, out, lda, ipiv.data());
93+
dgetrf_(&dim, &dim, out, &lda, ipiv, &info);
94+
if (info != 0)
95+
{
96+
std::cout << "ERROR: LAPACK dgetrf error, info = " << info << std::endl;
97+
exit(1);
98+
}
99+
dgetri_(&dim, out, &lda, ipiv, work, &lwork, &info);
100+
if (info != 0)
101+
{
102+
std::cout << "ERROR: LAPACK dgetri error, info = " << info << std::endl;
103+
exit(1);
104+
}
105+
106+
delete[] ipiv;
107+
delete[] work;
84108
}
85109
}

source/source_base/inverse_matrix.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ class Inverse_Matrix_Complex
2020
void init( const int &dim_in);
2121

2222
private:
23-
int lwork;
2423
int dim=0;
2524
double *e=nullptr;
25+
int lwork=0;
26+
std::complex<double> *work2=nullptr;
27+
double* rwork=nullptr;
28+
int info=0;
2629
bool allocate=false; //mohan add 2012-04-02
2730

2831
ModuleBase::ComplexMatrix EA;

source/source_base/kernels/cuda/math_kernel_op.cu

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ __global__ void matrix_copy_kernel(const int n1, const int n2, const T* A, const
133133
}
134134
}
135135

136+
template <typename T, typename Real>
137+
__global__ void matrix_multiply_vector_kernel(const int m, const int n, T *a, const int lda, const Real *b, const Real alpha, T *c, const int ldc){
138+
int row = blockIdx.x * blockDim.x + threadIdx.x;
139+
int col = blockIdx.y * blockDim.y + threadIdx.y;
140+
if (col >= n || row >= m) return;
141+
c[col * ldc + row] = a[col * lda + row] * b[col] * alpha;
142+
}
143+
136144
cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
137145
{
138146
if (trans == 'N')
@@ -147,7 +155,7 @@ cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char*
147155
{
148156
return CUBLAS_OP_C;
149157
}
150-
else
158+
else
151159
{
152160
ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
153161
}
@@ -438,10 +446,44 @@ void matrixCopy<std::complex<double>, base_device::DEVICE_GPU>::operator()(const
438446
cudaCheckOnDebug();
439447
}
440448

449+
template <>
450+
void matrix_mul_vector_op<double, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
451+
double *a, const int &lda, const double *b, const double alpha, double *c, const int &ldc){
452+
dim3 thread(16, 16, 1);
453+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
454+
matrix_multiply_vector_kernel<double, double> <<<block, thread >>>(m, n, a, lda,
455+
b, alpha, c, ldc);
456+
cudaCheckOnDebug();
457+
}
458+
459+
template <>
460+
void matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
461+
std::complex<float> *a, const int &lda, const float *b, const float alpha, std::complex<float> *c, const int &ldc){
462+
dim3 thread(16, 16, 1);
463+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
464+
matrix_multiply_vector_kernel<thrust::complex<float>, float> <<<block, thread >>>(m, n, reinterpret_cast<thrust::complex<float>*>(a), lda,
465+
b, alpha, reinterpret_cast<thrust::complex<float>*>(c), ldc);
466+
cudaCheckOnDebug();
467+
}
468+
469+
template <>
470+
void matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
471+
std::complex<double> *a, const int &lda, const double *b, const double alpha, std::complex<double> *c, const int &ldc)
472+
{
473+
dim3 thread(16, 16, 1);
474+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
475+
matrix_multiply_vector_kernel<thrust::complex<double>, double> <<<block, thread >>>(m, n, reinterpret_cast<thrust::complex<double>*>(a), lda,
476+
b, alpha, reinterpret_cast<thrust::complex<double>*>(c), ldc);
477+
cudaCheckOnDebug();
478+
}
441479

442480
// Explicitly instantiate functors for the types of functor registered.
443481

444482
template struct matrixCopy<std::complex<float>, base_device::DEVICE_GPU>;
445483
template struct matrixCopy<double, base_device::DEVICE_GPU>;
446484
template struct matrixCopy<std::complex<double>, base_device::DEVICE_GPU>;
485+
486+
template struct matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
487+
template struct matrix_mul_vector_op<double, base_device::DEVICE_GPU>;
488+
template struct matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
447489
} // namespace ModuleBase

0 commit comments

Comments
 (0)