Skip to content

Commit 1ae41a1

Browse files
committed
Clean the code and add docs
1 parent fe49ac8 commit 1ae41a1

File tree

3 files changed

+46
-28
lines changed

3 files changed

+46
-28
lines changed

source/source_base/module_container/ATen/kernels/lapack.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <base/third_party/lapack.h>
55

6-
// #include <cstring> // std::memcpy
76
#include <algorithm> // std::copy
87
#include <complex>
98
#include <stdexcept>
@@ -208,10 +207,6 @@ struct lapack_hegvd<T, DEVICE_CPU> {
208207
// first copy Mat_A to eigen_vec
209208
// then pass as argument "A" in lapack hegvd
210209
// and this block of memory will be overwritten by eigenvectors
211-
// for (int i = 0; i < dim * lda; ++i){
212-
// eigen_vec[i] = Mat_A[i];
213-
// }
214-
// std::memcpy(eigen_vec, Mat_A, sizeof(T) * dim * lda);
215210
// eigen_vec = Mat_A
216211
std::copy(Mat_A, Mat_A + dim*lda, eigen_vec);
217212

source/source_base/module_container/ATen/kernels/lapack.h

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,18 @@ struct lapack_potrf {
4040
const int& lda);
4141
};
4242

43-
43+
// ============================================================================
44+
// Standard Hermitian Eigenvalue Problem Solvers
45+
// ============================================================================
46+
// The following structures (lapack_heevd and lapack_heevx) implement solvers
47+
// for standard Hermitian eigenvalue problems of the form:
48+
// A * x = lambda * x
49+
// where:
50+
// - A is a Hermitian matrix
51+
// - lambda are the eigenvalues to be computed
52+
// - x are the corresponding eigenvectors
53+
//
54+
// ============================================================================
4455
template <typename T, typename Device>
4556
struct lapack_heevd {
4657
using Real = typename GetTypeReal<T>::type;
@@ -61,18 +72,15 @@ struct lapack_heevx {
6172
* This function solves the problem A*x = lambda*x, where A is a Hermitian matrix.
6273
* It computes a subset of eigenvalues and, optionally, the corresponding eigenvectors.
6374
*
64-
* @param jobz 'N': Compute eigenvalues only; 'V': Compute eigenvalues and eigenvectors.
65-
* @param range 'A': All eigenvalues; 'V': Eigenvalues in the half-open interval (vl, vu]; 'I': Eigenvalues with indices il through iu.
66-
* @param uplo 'U': Upper triangle of A is stored; 'L': Lower triangle is stored.
6775
* @param dim The order of the matrix A. dim >= 0.
68-
* @param Mat On entry, the Hermitian matrix A. On exit, it may be overwritten.
69-
* @param vl Lower bound of the interval to search for eigenvalues if range == 'V'.
70-
* @param vu Upper bound of the interval to search for eigenvalues if range == 'V'.
71-
* @param il Index of the smallest eigenvalue to be returned if range == 'I'.
72-
* @param iu Index of the largest eigenvalue to be returned if range == 'I'.
73-
* @param m Output: The total number of found eigenvalues.
74-
* @param eigen_val Array to store the computed eigenvalues in ascending order.
75-
* @param eigen_vec If not nullptr and jobz == 'V', array to store the computed eigenvectors.
76+
* @param lda The leading dimension of the array Mat. lda >= max(1, dim).
77+
* @param Mat On entry, the Hermitian matrix A. On exit, A is kept.
78+
* @param neig The number of eigenvalues to be found. 0 <= neig <= dim.
79+
* @param eigen_val On normal exit, the first \p neig elements contain the selected
80+
* eigenvalues in ascending order.
81+
* @param eigen_vec If eigen_vec is not nullptr, then on exit it contains the
82+
* orthonormal eigenvectors of the matrix A. The eigenvectors are stored in
83+
* the columns of eigen_vec, in the same order as the eigenvalues.
7684
*
7785
* @note
7886
* See LAPACK ZHEEVX or CHEEVX documentation for more details.
@@ -87,6 +95,21 @@ struct lapack_heevx {
8795
T *eigen_vec);
8896
};
8997

98+
99+
// ============================================================================
100+
// Generalized Hermitian-definite Eigenvalue Problem Solvers
101+
// ============================================================================
102+
// The following structures (lapack_hegvd and lapack_hegvx) implement solvers
103+
// for generalized Hermitian-definite eigenvalue problems of the form:
104+
// A * x = lambda * B * x
105+
// where:
106+
// - A is a Hermitian matrix
107+
// - B is a Hermitian positive definite matrix
108+
// - lambda are the eigenvalues to be computed
109+
// - x are the corresponding eigenvectors
110+
//
111+
// ============================================================================
112+
90113
template <typename T, typename Device>
91114
struct lapack_hegvd {
92115
using Real = typename GetTypeReal<T>::type;

source/source_base/module_container/ATen/kernels/test/lapack_test.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ TYPED_TEST(LapackTest, Potrf) {
9292
EXPECT_EQ(A, C);
9393
}
9494

95+
// Test for lapack_heevd and lapack_heevx:
96+
// Solve a standard eigenvalue problem
97+
// and check that A*V = V*E
9598
TYPED_TEST(LapackTest, heevd) {
9699
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
97100
using Real = typename GetTypeReal<Type>::type;
@@ -179,19 +182,19 @@ TYPED_TEST(LapackTest, heevx) {
179182

180183
// Check the eigenvalues and eigenvectors
181184
// A * x = lambda * x for the first neig eigenvectors
182-
// get A*V
185+
// check that A * V = V * E
186+
// get A * V
183187
gemmCalculator(trans, trans, m, n, k, &alpha, A.data<Type>(), m, V.data<Type>(), k, &beta, expected_C1.data<Type>(), m);
184-
// get E*V
188+
// get V * E
185189
for (int ii = 0; ii < neig; ii++) {
186190
axpyCalculator(dim, Alpha.data<Type>() + ii, V.data<Type>() + ii * dim, 1, expected_C2.data<Type>() + ii * dim, 1);
187-
}
188-
// check that A*V = E*V
189-
E = E.to_device<DEVICE_CPU>();
190-
V = V.to_device<DEVICE_CPU>();
191191

192192
EXPECT_EQ(expected_C1, expected_C2);
193193
}
194194

195+
// Test for lapack_hegvd and lapack_hegvx
196+
// Solve a generalized eigenvalue problem
197+
// and check that A * v = e * B * v
195198
TYPED_TEST(LapackTest, hegvd) {
196199
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
197200
using Real = typename GetTypeReal<Type>::type;
@@ -288,7 +291,8 @@ TYPED_TEST(LapackTest, hegvx) {
288291

289292
// Check the eigenvalues and eigenvectors
290293
// A * x = lambda * B * x for the first neig eigenvectors
291-
// get A*V
294+
// check that A * V = E * B * V
295+
// get A * V
292296
gemmCalculator(trans, trans, m, n, k, &alpha, A.data<Type>(), m, V.data<Type>(), k, &beta, expected_C1.data<Type>(), m);
293297
// get E * B * V
294298
// where B is 2 * eye(3,3)
@@ -298,10 +302,6 @@ TYPED_TEST(LapackTest, hegvx) {
298302
for (int ii = 0; ii < neig; ii++) {
299303
axpyCalculator(dim, Alpha.data<Type>() + ii, C_temp.data<Type>() + ii * dim, 1, expected_C2.data<Type>() + ii * dim, 1);
300304
}
301-
// check that A*V = E*V
302-
E = E.to_device<DEVICE_CPU>();
303-
V = V.to_device<DEVICE_CPU>();
304-
305305

306306
EXPECT_EQ(expected_C1, expected_C2);
307307
}

0 commit comments

Comments
 (0)