Skip to content

Commit 7da4779

Browse files
committed
Sort and classify lapack routines
1 parent b639675 commit 7da4779

File tree

3 files changed

+180
-149
lines changed

3 files changed

+180
-149
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ struct set_matrix<T, DEVICE_GPU> {
6262
}
6363
};
6464

65+
66+
67+
// --- 1. Matrix Decomposition & Inversion ---
6568
template <typename T>
6669
struct lapack_trtri<T, DEVICE_GPU> {
6770
void operator()(
@@ -90,6 +93,53 @@ struct lapack_potrf<T, DEVICE_GPU> {
9093
}
9194
};
9295

96+
// GPU specialization of lapack_getrf.
// operator() forwards m, n, Mat, lda and ipiv unchanged to
// cuSolverConnector::getrf, together with the cusolver_handle that is in scope
// in this file.
// NOTE(review): going by the LAPACK routine name, this is presumably an LU
// factorization with pivot indices written to ipiv, and Mat/ipiv are presumably
// device pointers -- confirm against cuSolverConnector::getrf.
// No status value is inspected in this wrapper; whether failures are reported
// depends on cuSolverConnector::getrf.
template <typename T>
97+
struct lapack_getrf<T, DEVICE_GPU> {
98+
void operator()(
99+
const int& m,
100+
const int& n,
101+
T* Mat,
102+
const int& lda,
103+
int* ipiv)
104+
{
105+
cuSolverConnector::getrf(cusolver_handle, m, n, Mat, lda, ipiv);
106+
}
107+
};
108+
109+
// GPU specialization of lapack_getri.
// This is a deliberate stub: operator() ignores every argument and
// unconditionally throws std::runtime_error, whose message directs callers to
// use getrs instead. The parameter list is kept so the call signature matches
// the other lapack_getri specializations.
template <typename T>
110+
struct lapack_getri<T, DEVICE_GPU> {
111+
void operator()(
112+
const int& n,
113+
T* Mat,
114+
const int& lda,
115+
const int* ipiv,
116+
T* work,
117+
const int& lwork)
118+
{
119+
// Always throws; none of the parameters above are read.
throw std::runtime_error("cuSOLVER does not provide LU-based matrix inversion interface (getri). To compute the inverse on GPU, use getrs instead.");
120+
}
121+
};
122+
123+
124+
// --- 2. Linear System Solvers ---
125+
// GPU specialization of lapack_getrs.
// operator() forwards trans, n, nrhs, A, lda, ipiv, B and ldb unchanged to
// cuSolverConnector::getrs, together with the cusolver_handle that is in scope
// in this file.
// NOTE(review): going by the LAPACK routine name, this presumably solves a
// linear system using an LU factorization produced by getrf, with the solution
// overwriting B -- confirm against cuSolverConnector::getrs.
// No status value is inspected in this wrapper.
template <typename T>
126+
struct lapack_getrs<T, DEVICE_GPU> {
127+
void operator()(
128+
const char& trans,
129+
const int& n,
130+
const int& nrhs,
131+
T* A,
132+
const int& lda,
133+
const int* ipiv,
134+
T* B,
135+
const int& ldb)
136+
{
137+
cuSolverConnector::getrs(cusolver_handle, trans, n, nrhs, A, lda, ipiv, B, ldb);
138+
}
139+
};
140+
141+
142+
// --- 3. Standard & Generalized Eigenvalue Problems ---
93143
template <typename T>
94144
struct lapack_heevd<T, DEVICE_GPU> {
95145
using Real = typename GetTypeReal<T>::type;
@@ -198,49 +248,6 @@ struct lapack_hegvx<T, DEVICE_GPU> {
198248

199249

200250

201-
template <typename T>
202-
struct lapack_getrf<T, DEVICE_GPU> {
203-
void operator()(
204-
const int& m,
205-
const int& n,
206-
T* Mat,
207-
const int& lda,
208-
int* ipiv)
209-
{
210-
cuSolverConnector::getrf(cusolver_handle, m, n, Mat, lda, ipiv);
211-
}
212-
};
213-
214-
template <typename T>
215-
struct lapack_getri<T, DEVICE_GPU> {
216-
void operator()(
217-
const int& n,
218-
T* Mat,
219-
const int& lda,
220-
const int* ipiv,
221-
T* work,
222-
const int& lwork)
223-
{
224-
throw std::runtime_error("cuSOLVER does not provide LU-based matrix inversion interface (getri). To compute the inverse on GPU, use getrs instead.");
225-
}
226-
};
227-
228-
template <typename T>
229-
struct lapack_getrs<T, DEVICE_GPU> {
230-
void operator()(
231-
const char& trans,
232-
const int& n,
233-
const int& nrhs,
234-
T* A,
235-
const int& lda,
236-
const int* ipiv,
237-
T* B,
238-
const int& ldb)
239-
{
240-
cuSolverConnector::getrs(cusolver_handle, trans, n, nrhs, A, lda, ipiv, B, ldb);
241-
}
242-
};
243-
244251
template struct set_matrix<float, DEVICE_GPU>;
245252
template struct set_matrix<double, DEVICE_GPU>;
246253
template struct set_matrix<std::complex<float>, DEVICE_GPU>;
@@ -256,6 +263,13 @@ template struct lapack_potrf<double, DEVICE_GPU>;
256263
template struct lapack_potrf<std::complex<float>, DEVICE_GPU>;
257264
template struct lapack_potrf<std::complex<double>, DEVICE_GPU>;
258265

266+
267+
// Explicit instantiations of lapack_getrs<T, DEVICE_GPU> for the four element
// types used here: float, double, std::complex<float>, std::complex<double>.
template struct lapack_getrs<float, DEVICE_GPU>;
268+
template struct lapack_getrs<double, DEVICE_GPU>;
269+
template struct lapack_getrs<std::complex<float>, DEVICE_GPU>;
270+
template struct lapack_getrs<std::complex<double>, DEVICE_GPU>;
271+
272+
259273
template struct lapack_heevd<float, DEVICE_GPU>;
260274
template struct lapack_heevd<double, DEVICE_GPU>;
261275
template struct lapack_heevd<std::complex<float>, DEVICE_GPU>;
@@ -286,10 +300,7 @@ template struct lapack_getri<double, DEVICE_GPU>;
286300
template struct lapack_getri<std::complex<float>, DEVICE_GPU>;
287301
template struct lapack_getri<std::complex<double>, DEVICE_GPU>;
288302

289-
template struct lapack_getrs<float, DEVICE_GPU>;
290-
template struct lapack_getrs<double, DEVICE_GPU>;
291-
template struct lapack_getrs<std::complex<float>, DEVICE_GPU>;
292-
template struct lapack_getrs<std::complex<double>, DEVICE_GPU>;
303+
293304

294305
} // namespace kernels
295306
} // namespace container

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 80 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ struct set_matrix<T, DEVICE_CPU> {
4040
}
4141
};
4242

43+
// --- 1. Matrix Decomposition & Inversion ---
4344
template <typename T>
4445
struct lapack_trtri<T, DEVICE_CPU> {
4546
void operator()(
@@ -73,6 +74,66 @@ struct lapack_potrf<T, DEVICE_CPU> {
7374
}
7475
};
7576

77+
78+
// CPU specialization of lapack_getrf.
// operator() calls lapackConnector::getrf with m, n, Mat, lda, ipiv and a local
// info flag initialised to 0.
// Throws std::runtime_error("getrf failed with info = <info>") when info is
// non-zero after the call.
// NOTE(review): going by the LAPACK routine name, this is presumably an LU
// factorization with pivot indices written to ipiv -- confirm against
// lapackConnector::getrf.
template <typename T>
79+
struct lapack_getrf<T, DEVICE_CPU> {
80+
void operator()(
81+
const int& m,
82+
const int& n,
83+
T* Mat,
84+
const int& lda,
85+
int* ipiv)
86+
{
87+
int info = 0;
88+
lapackConnector::getrf(m, n, Mat, lda, ipiv, info);
89+
// Any non-zero info is treated as failure and surfaced as an exception.
if (info != 0) {
90+
throw std::runtime_error("getrf failed with info = " + std::to_string(info));
91+
}
92+
}
93+
};
94+
95+
// CPU specialization of lapack_getri.
// operator() calls lapackConnector::getri with n, Mat, lda, ipiv, the
// caller-supplied workspace (work, lwork) and a local info flag initialised
// to 0.
// Throws std::runtime_error("getri failed with info = <info>") when info is
// non-zero after the call.
// NOTE(review): going by the LAPACK routine name, this presumably computes a
// matrix inverse from an LU factorization produced by getrf -- confirm against
// lapackConnector::getri.
template <typename T>
96+
struct lapack_getri<T, DEVICE_CPU> {
97+
void operator()(
98+
const int& n,
99+
T* Mat,
100+
const int& lda,
101+
const int* ipiv,
102+
T* work,
103+
const int& lwork)
104+
{
105+
int info = 0;
106+
lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info);
107+
// Any non-zero info is treated as failure and surfaced as an exception.
if (info != 0) {
108+
throw std::runtime_error("getri failed with info = " + std::to_string(info));
109+
}
110+
}
111+
};
112+
113+
114+
// --- 2. Linear System Solvers ---
115+
// CPU specialization of lapack_getrs.
// operator() calls lapackConnector::getrs with trans, n, nrhs, A, lda, ipiv, B,
// ldb and a local info flag initialised to 0.
// Throws std::runtime_error("getrs failed with info = <info>") when info is
// non-zero after the call.
// NOTE(review): going by the LAPACK routine name, this presumably solves a
// linear system using an LU factorization produced by getrf, with the solution
// overwriting B -- confirm against lapackConnector::getrs.
template <typename T>
116+
struct lapack_getrs<T, DEVICE_CPU> {
117+
void operator()(
118+
const char& trans,
119+
const int& n,
120+
const int& nrhs,
121+
T* A,
122+
const int& lda,
123+
const int* ipiv,
124+
T* B,
125+
const int& ldb)
126+
{
127+
int info = 0;
128+
lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb, info);
129+
// Any non-zero info is treated as failure and surfaced as an exception.
if (info != 0) {
130+
throw std::runtime_error("getrs failed with info = " + std::to_string(info));
131+
}
132+
}
133+
};
134+
135+
136+
// --- 3. Standard & Generalized Eigenvalue Problems ---
76137
template <typename T>
77138
struct lapack_heevd<T, DEVICE_CPU> {
78139
using Real = typename GetTypeReal<T>::type;
@@ -338,60 +399,9 @@ struct lapack_hegvx<T, DEVICE_CPU> {
338399
}
339400
};
340401

341-
template <typename T>
342-
struct lapack_getrf<T, DEVICE_CPU> {
343-
void operator()(
344-
const int& m,
345-
const int& n,
346-
T* Mat,
347-
const int& lda,
348-
int* ipiv)
349-
{
350-
int info = 0;
351-
lapackConnector::getrf(m, n, Mat, lda, ipiv, info);
352-
if (info != 0) {
353-
throw std::runtime_error("getrf failed with info = " + std::to_string(info));
354-
}
355-
}
356-
};
357402

358-
template <typename T>
359-
struct lapack_getri<T, DEVICE_CPU> {
360-
void operator()(
361-
const int& n,
362-
T* Mat,
363-
const int& lda,
364-
const int* ipiv,
365-
T* work,
366-
const int& lwork)
367-
{
368-
int info = 0;
369-
lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info);
370-
if (info != 0) {
371-
throw std::runtime_error("getri failed with info = " + std::to_string(info));
372-
}
373-
}
374-
};
375403

376-
template <typename T>
377-
struct lapack_getrs<T, DEVICE_CPU> {
378-
void operator()(
379-
const char& trans,
380-
const int& n,
381-
const int& nrhs,
382-
T* A,
383-
const int& lda,
384-
const int* ipiv,
385-
T* B,
386-
const int& ldb)
387-
{
388-
int info = 0;
389-
lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb, info);
390-
if (info != 0) {
391-
throw std::runtime_error("getrs failed with info = " + std::to_string(info));
392-
}
393-
}
394-
};
404+
395405

396406
template struct set_matrix<float, DEVICE_CPU>;
397407
template struct set_matrix<double, DEVICE_CPU>;
@@ -408,6 +418,24 @@ template struct lapack_trtri<double, DEVICE_CPU>;
408418
template struct lapack_trtri<std::complex<float>, DEVICE_CPU>;
409419
template struct lapack_trtri<std::complex<double>, DEVICE_CPU>;
410420

421+
422+
// Explicit instantiations of the CPU lapack_getrf, lapack_getri and
// lapack_getrs functors, each for float, double, std::complex<float> and
// std::complex<double>.
template struct lapack_getrf<float, DEVICE_CPU>;
423+
template struct lapack_getrf<double, DEVICE_CPU>;
424+
template struct lapack_getrf<std::complex<float>, DEVICE_CPU>;
425+
template struct lapack_getrf<std::complex<double>, DEVICE_CPU>;
426+
427+
template struct lapack_getri<float, DEVICE_CPU>;
428+
template struct lapack_getri<double, DEVICE_CPU>;
429+
template struct lapack_getri<std::complex<float>, DEVICE_CPU>;
430+
template struct lapack_getri<std::complex<double>, DEVICE_CPU>;
431+
432+
433+
template struct lapack_getrs<float, DEVICE_CPU>;
434+
template struct lapack_getrs<double, DEVICE_CPU>;
435+
template struct lapack_getrs<std::complex<float>, DEVICE_CPU>;
436+
template struct lapack_getrs<std::complex<double>, DEVICE_CPU>;
437+
438+
411439
template struct lapack_heevd<float, DEVICE_CPU>;
412440
template struct lapack_heevd<double, DEVICE_CPU>;
413441
template struct lapack_heevd<std::complex<float>, DEVICE_CPU>;
@@ -428,20 +456,5 @@ template struct lapack_hegvx<double, DEVICE_CPU>;
428456
template struct lapack_hegvx<std::complex<float>, DEVICE_CPU>;
429457
template struct lapack_hegvx<std::complex<double>, DEVICE_CPU>;
430458

431-
template struct lapack_getrf<float, DEVICE_CPU>;
432-
template struct lapack_getrf<double, DEVICE_CPU>;
433-
template struct lapack_getrf<std::complex<float>, DEVICE_CPU>;
434-
template struct lapack_getrf<std::complex<double>, DEVICE_CPU>;
435-
436-
template struct lapack_getri<float, DEVICE_CPU>;
437-
template struct lapack_getri<double, DEVICE_CPU>;
438-
template struct lapack_getri<std::complex<float>, DEVICE_CPU>;
439-
template struct lapack_getri<std::complex<double>, DEVICE_CPU>;
440-
441-
template struct lapack_getrs<float, DEVICE_CPU>;
442-
template struct lapack_getrs<double, DEVICE_CPU>;
443-
template struct lapack_getrs<std::complex<float>, DEVICE_CPU>;
444-
template struct lapack_getrs<std::complex<double>, DEVICE_CPU>;
445-
446459
} // namespace kernels
447460
} // namespace container

0 commit comments

Comments
 (0)