Skip to content

Commit 11c19c8

Browse files
committed
Add test for hegvx
1 parent 5f05bf4 commit 11c19c8

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

source/source_base/module_container/ATen/kernels/test/lapack_test.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,19 @@ TYPED_TEST(LapackTest, heevx) {
188188
// check that A*V = E*V
189189
E = E.to_device<DEVICE_CPU>();
190190
V = V.to_device<DEVICE_CPU>();
191+
std::cout << "Eigenvalues E: ";
192+
for (int i = 0; i < neig; i++) {
193+
std::cout << E.data<Real>()[i] << " ";
194+
}
195+
std::cout << std::endl;
196+
197+
std::cout << "Eigenvectors V:" << std::endl;
198+
for (int i = 0; i < dim; i++) {
199+
for (int j = 0; j < neig; j++) {
200+
std::cout << V.data<Type>()[i + j * dim] << " ";
201+
}
202+
std::cout << std::endl;
203+
}
191204

192205
EXPECT_EQ(expected_C1, expected_C2);
193206
}
@@ -242,7 +255,69 @@ TYPED_TEST(LapackTest, hegvd) {
242255
EXPECT_EQ(expected_C1, expected_C2);
243256
}
244257

258+
TYPED_TEST(LapackTest, hegvx) {
259+
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
260+
using Real = typename GetTypeReal<Type>::type;
261+
using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;
245262

263+
blas_gemm<Type, Device> gemmCalculator;
264+
blas_axpy<Type, Device> axpyCalculator;
265+
lapack_hegvx<Type, Device> hegvxCalculator;
266+
267+
const int dim = 3;
268+
const int neig = 2; // Compute first 2 eigenvalues
269+
270+
Tensor A = std::move(Tensor({static_cast<Type>(4.0), static_cast<Type>(1.0), static_cast<Type>(1.0),
271+
static_cast<Type>(1.0), static_cast<Type>(5.0), static_cast<Type>(3.0),
272+
static_cast<Type>(1.0), static_cast<Type>(3.0), static_cast<Type>(6.0)}).to_device<Device>());
273+
274+
Tensor B = std::move(Tensor({static_cast<Type>(2.0), static_cast<Type>(0.0), static_cast<Type>(0.0),
275+
static_cast<Type>(0.0), static_cast<Type>(2.0), static_cast<Type>(0.0),
276+
static_cast<Type>(0.0), static_cast<Type>(0.0), static_cast<Type>(2.0)}).to_device<Device>());
277+
278+
Tensor E = std::move(Tensor({static_cast<Real>(0.0), static_cast<Real>(0.0)}).to_device<Device>());
279+
Tensor V = A;
280+
Tensor expected_C1 = std::move(Tensor({static_cast<Type>(0.0), static_cast<Type>(0.0), static_cast<Type>(0.0),
281+
static_cast<Type>(0.0), static_cast<Type>(0.0), static_cast<Type>(0.0)}).to_device<Device>());
282+
Tensor expected_C2 = expected_C1;
283+
Tensor C_temp = expected_C1;
284+
expected_C1.zero();
285+
expected_C2.zero();
286+
287+
const char trans = 'N';
288+
const int m = 3;
289+
const int n = neig;
290+
const int k = 3;
291+
const Type alpha = static_cast<Type>(1.0);
292+
const Type beta = static_cast<Type>(0.0);
293+
294+
// Compute first neig eigenvalues and eigenvectors using hegvx
295+
hegvxCalculator(dim, dim, A.data<Type>(), B.data<Type>(), neig, E.data<Real>(), V.data<Type>());
296+
297+
E = E.to_device<ct::DEVICE_CPU>();
298+
const Tensor Alpha = std::move(Tensor({
299+
static_cast<Type>(E.data<Real>()[0]),
300+
static_cast<Type>(E.data<Real>()[1])}));
301+
302+
// Check the eigenvalues and eigenvectors
303+
// A * x = lambda * B * x for the first neig eigenvectors
304+
// get A*V
305+
gemmCalculator(trans, trans, m, n, k, &alpha, A.data<Type>(), m, V.data<Type>(), k, &beta, expected_C1.data<Type>(), m);
306+
// get E * B * V
307+
// where B is 2 * eye(3,3)
308+
// get C_temp = B * V first
309+
gemmCalculator(trans, trans, m, n, k, &alpha, B.data<Type>(), m, V.data<Type>(), k, &beta, C_temp.data<Type>(), m);
310+
// then compute C2 = E * B * V
311+
for (int ii = 0; ii < neig; ii++) {
312+
axpyCalculator(dim, Alpha.data<Type>() + ii, C_temp.data<Type>() + ii * dim, 1, expected_C2.data<Type>() + ii * dim, 1);
313+
}
314+
// check that A*V = E*V
315+
E = E.to_device<DEVICE_CPU>();
316+
V = V.to_device<DEVICE_CPU>();
317+
318+
319+
EXPECT_EQ(expected_C1, expected_C2);
320+
}
246321

247322
} // namespace kernels
248323
} // namespace container

0 commit comments

Comments
 (0)