@@ -138,6 +138,59 @@ TYPED_TEST(LapackTest, heevd) {
138138 EXPECT_EQ (expected_C1, expected_C2);
139139}
140140
141+ TYPED_TEST (LapackTest, heevx) {
142+ using Type = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
143+ using Real = typename GetTypeReal<Type>::type;
144+ using Device = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
145+
146+ blas_gemm<Type, Device> gemmCalculator;
147+ blas_axpy<Type, Device> axpyCalculator;
148+ lapack_heevx<Type, Device> heevxCalculator;
149+
150+ const int dim = 3 ;
151+ const int neig = 2 ; // Compute first 2 eigenvalues
152+
153+ Tensor A = std::move (Tensor ({static_cast <Type>(4.0 ), static_cast <Type>(1.0 ), static_cast <Type>(1.0 ),
154+ static_cast <Type>(1.0 ), static_cast <Type>(5.0 ), static_cast <Type>(3.0 ),
155+ static_cast <Type>(1.0 ), static_cast <Type>(3.0 ), static_cast <Type>(6.0 )}).to_device <Device>());
156+
157+ Tensor E = std::move (Tensor ({static_cast <Real>(0.0 ), static_cast <Real>(0.0 )}).to_device <Device>());
158+ Tensor V = A;
159+ Tensor expected_C1 = std::move (Tensor ({static_cast <Type>(0.0 ), static_cast <Type>(0.0 ), static_cast <Type>(0.0 ),
160+ static_cast <Type>(0.0 ), static_cast <Type>(0.0 ), static_cast <Type>(0.0 )}).to_device <Device>());
161+ Tensor expected_C2 = expected_C1;
162+ expected_C1.zero ();
163+ expected_C2.zero ();
164+
165+ const char trans = ' N' ;
166+ const int m = 3 ;
167+ const int n = neig;
168+ const int k = 3 ;
169+ const Type alpha = static_cast <Type>(1.0 );
170+ const Type beta = static_cast <Type>(0.0 );
171+
172+ // Compute first neig eigenvalues and eigenvectors using heevx
173+ heevxCalculator (dim, dim, A.data <Type>(), neig, E.data <Real>(), V.data <Type>());
174+
175+ E = E.to_device <ct::DEVICE_CPU>();
176+ const Tensor Alpha = std::move (Tensor ({
177+ static_cast <Type>(E.data <Real>()[0 ]),
178+ static_cast <Type>(E.data <Real>()[1 ])}));
179+
180+ // Check the eigenvalues and eigenvectors
181+ // A * x = lambda * x for the first neig eigenvectors
182+ // get A*V
183+ gemmCalculator (trans, trans, m, n, k, &alpha, A.data <Type>(), m, V.data <Type>(), k, &beta, expected_C1.data <Type>(), m);
184+ // get E*V
185+ for (int ii = 0 ; ii < neig; ii++) {
186+ axpyCalculator (dim, Alpha.data <Type>() + ii, V.data <Type>() + ii * dim, 1 , expected_C2.data <Type>() + ii * dim, 1 );
187+ }
188+ // check that A*V = E*V
189+ E = E.to_device <DEVICE_CPU>();
190+ V = V.to_device <DEVICE_CPU>();
191+
192+ EXPECT_EQ (expected_C1, expected_C2);
193+ }
141194
142195TYPED_TEST (LapackTest, hegvd) {
143196 using Type = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
@@ -189,5 +242,7 @@ TYPED_TEST(LapackTest, hegvd) {
189242 EXPECT_EQ (expected_C1, expected_C2);
190243}
191244
192- } // namespace op
245+
246+
247+ } // namespace kernels
193248} // namespace container
0 commit comments