Skip to content

Commit 5f05bf4

Browse files
committed
Add test for heevx
1 parent abf8340 commit 5f05bf4

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

source/source_base/module_container/ATen/kernels/cuda/lapack.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include <cuda_runtime.h>
77
#include <thrust/complex.h>
88

9+
#include <cassert>
10+
11+
912
namespace container {
1013
namespace kernels {
1114

@@ -112,6 +115,7 @@ struct lapack_heevx<T, DEVICE_GPU> {
112115
Real *d_eigen_val,
113116
T *d_eigen_vec)
114117
{
118+
assert(n <= lda);
115119
// copy d_Mat to d_eigen_vec, and results will be overwritten into d_eigen_vec
116120
// by cuSolver
117121
cudaErrcheck(cudaMemcpy(d_eigen_vec, d_Mat, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice));

source/source_base/module_container/ATen/kernels/test/lapack_test.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,59 @@ TYPED_TEST(LapackTest, heevd) {
138138
EXPECT_EQ(expected_C1, expected_C2);
139139
}
140140

141+
TYPED_TEST(LapackTest, heevx) {
142+
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
143+
using Real = typename GetTypeReal<Type>::type;
144+
using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;
145+
146+
blas_gemm<Type, Device> gemmCalculator;
147+
blas_axpy<Type, Device> axpyCalculator;
148+
lapack_heevx<Type, Device> heevxCalculator;
149+
150+
const int dim = 3;
151+
const int neig = 2; // Compute first 2 eigenvalues
152+
153+
Tensor A = std::move(Tensor({static_cast<Type>(4.0), static_cast<Type>(1.0), static_cast<Type>(1.0),
154+
static_cast<Type>(1.0), static_cast<Type>(5.0), static_cast<Type>(3.0),
155+
static_cast<Type>(1.0), static_cast<Type>(3.0), static_cast<Type>(6.0)}).to_device<Device>());
156+
157+
Tensor E = std::move(Tensor({static_cast<Real>(0.0), static_cast<Real>(0.0)}).to_device<Device>());
158+
Tensor V = A;
159+
Tensor expected_C1 = std::move(Tensor({static_cast<Type>(0.0), static_cast<Type>(0.0), static_cast<Type>(0.0),
160+
static_cast<Type>(0.0), static_cast<Type>(0.0), static_cast<Type>(0.0)}).to_device<Device>());
161+
Tensor expected_C2 = expected_C1;
162+
expected_C1.zero();
163+
expected_C2.zero();
164+
165+
const char trans = 'N';
166+
const int m = 3;
167+
const int n = neig;
168+
const int k = 3;
169+
const Type alpha = static_cast<Type>(1.0);
170+
const Type beta = static_cast<Type>(0.0);
171+
172+
// Compute first neig eigenvalues and eigenvectors using heevx
173+
heevxCalculator(dim, dim, A.data<Type>(), neig, E.data<Real>(), V.data<Type>());
174+
175+
E = E.to_device<ct::DEVICE_CPU>();
176+
const Tensor Alpha = std::move(Tensor({
177+
static_cast<Type>(E.data<Real>()[0]),
178+
static_cast<Type>(E.data<Real>()[1])}));
179+
180+
// Check the eigenvalues and eigenvectors
181+
// A * x = lambda * x for the first neig eigenvectors
182+
// get A*V
183+
gemmCalculator(trans, trans, m, n, k, &alpha, A.data<Type>(), m, V.data<Type>(), k, &beta, expected_C1.data<Type>(), m);
184+
// get E*V
185+
for (int ii = 0; ii < neig; ii++) {
186+
axpyCalculator(dim, Alpha.data<Type>() + ii, V.data<Type>() + ii * dim, 1, expected_C2.data<Type>() + ii * dim, 1);
187+
}
188+
// check that A*V = E*V
189+
E = E.to_device<DEVICE_CPU>();
190+
V = V.to_device<DEVICE_CPU>();
191+
192+
EXPECT_EQ(expected_C1, expected_C2);
193+
}
141194

142195
TYPED_TEST(LapackTest, hegvd) {
143196
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
@@ -189,5 +242,7 @@ TYPED_TEST(LapackTest, hegvd) {
189242
EXPECT_EQ(expected_C1, expected_C2);
190243
}
191244

192-
} // namespace op
245+
246+
247+
} // namespace kernels
193248
} // namespace container

0 commit comments

Comments
 (0)