22
33#include " Eigen/Eigenvalues"
44
5+ #include < algorithm>
56#include < iostream>
7+ #include < numeric>
8+ #include < vector>
69
710template <typename DataType>
8- void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
9- typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
11+ void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out,
12+ DataType *eigenvectors_out,
13+ DataType *in, uint64_t m, uint64_t n) {
14+ typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
15+ Eigen::RowMajor>
16+ RowMajorMatrix;
1017
1118 // Map the input matrix
1219 Eigen::Map<RowMajorMatrix> input (in, m, n);
@@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
2027 }
2128
2229 // Get the eigenvalues and eigenvectors
23- Eigen::Matrix<DataType, Eigen::Dynamic, 1 > eigenvalues = eigensolver.eigenvalues ();
30+ Eigen::Matrix<DataType, Eigen::Dynamic, 1 > eigenvalues =
31+ eigensolver.eigenvalues ();
2432 RowMajorMatrix eigenvectors = eigensolver.eigenvectors ();
2533
26- // Copy the eigenvalues to the output
27- std::memcpy (eigenvalues_out, eigenvalues.data (), m * sizeof (DataType));
34+ // Create a vector of indices and sort it based on eigenvalues in decreasing
35+ // order
36+ std::vector<int > indices (m);
37+ std::iota (indices.begin (), indices.end (), 0 );
38+ std::sort (indices.begin (), indices.end (), [&eigenvalues](int i, int j) {
39+ return std::abs (eigenvalues (i)) > std::abs (eigenvalues (j));
40+ });
41+
42+ // Sort eigenvalues and rearrange eigenvectors
43+ Eigen::Matrix<DataType, Eigen::Dynamic, 1 > sorted_eigenvalues (m);
44+ RowMajorMatrix sorted_eigenvectors (m, n);
45+ for (int i = 0 ; i < m; ++i) {
46+ sorted_eigenvalues (i) = eigenvalues (indices[i]);
47+ sorted_eigenvectors.col (i) = eigenvectors.col (indices[i]);
48+ }
49+
50+ // Copy the sorted eigenvalues to the output
51+ std::memcpy (eigenvalues_out, sorted_eigenvalues.data (), m * sizeof (DataType));
2852
29- // Copy the eigenvectors to the output
30- std::memcpy (eigenvectors_out, eigenvectors.data (), m * n * sizeof (DataType));
53+ // Copy the sorted eigenvectors to the output
54+ std::memcpy (eigenvectors_out, sorted_eigenvectors.data (),
55+ m * n * sizeof (DataType));
3156}
3257
3358template <typename DataType>
@@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
4065 uint64_t num_eigenvectors_dims = dim_sizes[2 ];
4166
4267 uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
43- std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
68+ std::vector<uint64_t > operand_dims (operand_dims_ptr,
69+ operand_dims_ptr + num_operand_dims);
4470
4571 uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3 ];
46- std::vector<uint64_t > eigenvalues_dims (eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
72+ std::vector<uint64_t > eigenvalues_dims (
73+ eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
4774
4875 uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4 ];
49- std::vector<uint64_t > eigenvectors_dims (eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
76+ std::vector<uint64_t > eigenvectors_dims (
77+ eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
5078
5179 uint64_t m = eigenvectors_dims[eigenvectors_dims.size () - 2 ];
5280 uint64_t n = eigenvectors_dims[eigenvectors_dims.size () - 1 ];
5381
54- auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
82+ auto leading_dimensions =
83+ std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
5584
5685 uint64_t batch_items = 1 ;
5786 for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
@@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
6190 DataType *eigenvalues = (DataType *)out[0 ];
6291 DataType *eigenvectors = (DataType *)out[1 ];
6392
64- uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ] * sizeof (DataType);
65- uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size () - 1 ] * eigenvectors_dims[eigenvectors_dims.size () - 2 ] * sizeof (DataType);
66- uint64_t inner_stride = m * n * sizeof (DataType);
93+ uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ];
94+ uint64_t eigenvectors_stride =
95+ eigenvectors_dims[eigenvectors_dims.size () - 1 ] *
96+ eigenvectors_dims[eigenvectors_dims.size () - 2 ];
97+ uint64_t inner_stride = m * n;
6798
6899 for (uint64_t i = 0 ; i < batch_items; i++) {
69100 single_matrix_eigh_cpu_custom_call<DataType>(
70101 eigenvalues + i * eigenvalues_stride,
71- eigenvectors + i * eigenvectors_stride,
72- operand + i * inner_stride / sizeof (DataType),
73- m, n);
102+ eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
103+ n);
74104 }
75105}
0 commit comments