Skip to content

Commit 0e6c42c

Browse files
committed
Core algorithm: RT-TD now has preliminary support for GPU computation
1 parent d732808 commit 0e6c42c

File tree

27 files changed

+1237
-554
lines changed

27 files changed

+1237
-554
lines changed

source/module_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,49 @@ struct lapack_dngvd<T, DEVICE_GPU> {
117117
}
118118
};
119119

120+
template <typename T>
121+
struct lapack_getrf<T, DEVICE_GPU> {
122+
void operator()(
123+
const int& m,
124+
const int& n,
125+
T* Mat,
126+
const int& lda,
127+
int* ipiv)
128+
{
129+
cuSolverConnector::getrf(cusolver_handle, m, n, Mat, lda, ipiv);
130+
}
131+
};
132+
133+
template <typename T>
134+
struct lapack_getri<T, DEVICE_GPU> {
135+
void operator()(
136+
const int& n,
137+
T* Mat,
138+
const int& lda,
139+
const int* ipiv,
140+
T* work,
141+
const int& lwork)
142+
{
143+
throw std::runtime_error("cuSOLVER does not provide LU-based matrix inversion interface (getri). To compute the inverse on GPU, use getrs instead.");
144+
}
145+
};
146+
147+
template <typename T>
148+
struct lapack_getrs<T, DEVICE_GPU> {
149+
void operator()(
150+
const char& trans,
151+
const int& n,
152+
const int& nrhs,
153+
T* A,
154+
const int& lda,
155+
const int* ipiv,
156+
T* B,
157+
const int& ldb)
158+
{
159+
cuSolverConnector::getrs(cusolver_handle, trans, n, nrhs, A, lda, ipiv, B, ldb);
160+
}
161+
};
162+
120163
template struct set_matrix<float, DEVICE_GPU>;
121164
template struct set_matrix<double, DEVICE_GPU>;
122165
template struct set_matrix<std::complex<float>, DEVICE_GPU>;
@@ -142,5 +185,20 @@ template struct lapack_dngvd<double, DEVICE_GPU>;
142185
template struct lapack_dngvd<std::complex<float>, DEVICE_GPU>;
143186
template struct lapack_dngvd<std::complex<double>, DEVICE_GPU>;
144187

188+
template struct lapack_getrf<float, DEVICE_GPU>;
189+
template struct lapack_getrf<double, DEVICE_GPU>;
190+
template struct lapack_getrf<std::complex<float>, DEVICE_GPU>;
191+
template struct lapack_getrf<std::complex<double>, DEVICE_GPU>;
192+
193+
template struct lapack_getri<float, DEVICE_GPU>;
194+
template struct lapack_getri<double, DEVICE_GPU>;
195+
template struct lapack_getri<std::complex<float>, DEVICE_GPU>;
196+
template struct lapack_getri<std::complex<double>, DEVICE_GPU>;
197+
198+
template struct lapack_getrs<float, DEVICE_GPU>;
199+
template struct lapack_getrs<double, DEVICE_GPU>;
200+
template struct lapack_getrs<std::complex<float>, DEVICE_GPU>;
201+
template struct lapack_getrs<std::complex<double>, DEVICE_GPU>;
202+
145203
} // namespace kernels
146204
} // namespace container

source/module_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ struct lapack_getrf<T, DEVICE_CPU> {
131131
const int& n,
132132
T* Mat,
133133
const int& lda,
134-
int* ipiv,
135-
int& info)
134+
int* ipiv)
136135
{
136+
int info = 0;
137137
lapackConnector::getrf(m, n, Mat, lda, ipiv, info);
138138
if (info != 0) {
139139
throw std::runtime_error("getrf failed with info = " + std::to_string(info));
@@ -149,16 +149,36 @@ struct lapack_getri<T, DEVICE_CPU> {
149149
const int& lda,
150150
const int* ipiv,
151151
T* work,
152-
const int& lwork,
153-
int& info)
152+
const int& lwork)
154153
{
154+
int info = 0;
155155
lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info);
156156
if (info != 0) {
157157
throw std::runtime_error("getri failed with info = " + std::to_string(info));
158158
}
159159
}
160160
};
161161

162+
template <typename T>
163+
struct lapack_getrs<T, DEVICE_CPU> {
164+
void operator()(
165+
const char& trans,
166+
const int& n,
167+
const int& nrhs,
168+
T* A,
169+
const int& lda,
170+
const int* ipiv,
171+
T* B,
172+
const int& ldb)
173+
{
174+
int info = 0;
175+
lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb, info);
176+
if (info != 0) {
177+
throw std::runtime_error("getrs failed with info = " + std::to_string(info));
178+
}
179+
}
180+
};
181+
162182
template struct set_matrix<float, DEVICE_CPU>;
163183
template struct set_matrix<double, DEVICE_CPU>;
164184
template struct set_matrix<std::complex<float>, DEVICE_CPU>;
@@ -194,5 +214,10 @@ template struct lapack_getri<double, DEVICE_CPU>;
194214
template struct lapack_getri<std::complex<float>, DEVICE_CPU>;
195215
template struct lapack_getri<std::complex<double>, DEVICE_CPU>;
196216

217+
template struct lapack_getrs<float, DEVICE_CPU>;
218+
template struct lapack_getrs<double, DEVICE_CPU>;
219+
template struct lapack_getrs<std::complex<float>, DEVICE_CPU>;
220+
template struct lapack_getrs<std::complex<double>, DEVICE_CPU>;
221+
197222
} // namespace kernels
198223
} // namespace container

source/module_base/module_container/ATen/kernels/lapack.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ struct lapack_getrf {
7373
const int& n,
7474
T* Mat,
7575
const int& lda,
76-
int* ipiv,
77-
int& info);
76+
int* ipiv);
7877
};
7978

8079

@@ -86,10 +85,21 @@ struct lapack_getri {
8685
const int& lda,
8786
const int* ipiv,
8887
T* work,
89-
const int& lwork,
90-
int& info);
88+
const int& lwork);
9189
};
9290

91+
template <typename T, typename Device>
92+
struct lapack_getrs {
93+
void operator()(
94+
const char& trans,
95+
const int& n,
96+
const int& nrhs,
97+
T* A,
98+
const int& lda,
99+
const int* ipiv,
100+
T* B,
101+
const int& ldb);
102+
};
93103

94104
#if defined(__CUDA) || defined(__ROCM)
95105
// TODO: Use C++ singleton to manage the GPU handles

0 commit comments

Comments
 (0)