1
- #include " custom_calls.h"
2
-
3
- #include " Eigen/Dense"
4
- #include " Eigen/Eigenvalues"
5
- #include " Eigen/QR"
6
- #include " exla_nif_util.h"
7
1
#include " xla/service/custom_call_target_registry.h"
8
2
9
- template <typename DataType>
10
- void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
11
- typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
12
-
13
- // Map the input matrix
14
- Eigen::Map<RowMajorMatrix> input (in, m, n);
15
-
16
- // Compute the Eigenvalue decomposition
17
- Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver (input);
18
-
19
- if (eigensolver.info () != Eigen::Success) {
20
- std::cerr << " Eigenvalue decomposition failed!" << std::endl;
21
- return ;
22
- }
23
-
24
- // Get the eigenvalues and eigenvectors
25
- Eigen::Matrix<DataType, Eigen::Dynamic, 1 > eigenvalues = eigensolver.eigenvalues ();
26
- RowMajorMatrix eigenvectors = eigensolver.eigenvectors ();
27
-
28
- // Copy the eigenvalues to the output
29
- std::memcpy (eigenvalues_out, eigenvalues.data (), m * sizeof (DataType));
30
-
31
- // Copy the eigenvectors to the output
32
- std::memcpy (eigenvectors_out, eigenvectors.data (), m * n * sizeof (DataType));
33
- }
34
-
35
- template <typename DataType>
36
- void single_matrix_qr_cpu_custom_call (DataType *q_out, DataType *r_out, DataType *in, uint64_t m, uint64_t k, uint64_t n, bool complete) {
37
- typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
38
-
39
- Eigen::Map<RowMajorMatrix> input (in, m, n);
40
- Eigen::HouseholderQR<RowMajorMatrix> qr = input.householderQr ();
41
-
42
- RowMajorMatrix Q, R;
43
- size_t num_bytes_q, num_bytes_r;
44
-
45
- if (complete) {
46
- Q = qr.householderQ () * RowMajorMatrix::Identity (m, m);
47
- R = qr.matrixQR ();
48
-
49
- num_bytes_q = m * m * sizeof (DataType);
50
-
51
- for (uint64_t i = 0 ; i < m; ++i) {
52
- for (uint64_t j = 0 ; j < n; ++j) {
53
- r_out[i * n + j] = (j >= i) ? R (i, j) : static_cast <DataType>(0.0 );
54
- }
55
- }
56
- } else {
57
- Q = qr.householderQ () * RowMajorMatrix::Identity (m, k);
58
- R = qr.matrixQR ().topRows (k);
59
-
60
- num_bytes_q = m * k * sizeof (DataType);
61
-
62
- for (uint64_t i = 0 ; i < k; ++i) {
63
- for (uint64_t j = 0 ; j < n; ++j) {
64
- r_out[i * n + j] = (j >= i) ? R (i, j) : static_cast <DataType>(0.0 );
65
- }
66
- }
67
- }
68
-
69
- memcpy (q_out, Q.data (), num_bytes_q);
70
- }
71
-
72
- template <typename DataType>
73
- void qr_cpu_custom_call (void *out[], const void *in[]) {
74
- DataType *operand = (DataType *)in[0 ];
75
-
76
- uint64_t *dim_sizes = (uint64_t *)in[1 ];
77
- uint64_t num_operand_dims = dim_sizes[0 ];
78
- uint64_t num_q_dims = dim_sizes[1 ];
79
- uint64_t num_r_dims = dim_sizes[2 ];
80
-
81
- uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
82
- std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
83
-
84
- uint64_t *q_dims_ptr = (uint64_t *)in[3 ];
85
- std::vector<uint64_t > q_dims (q_dims_ptr, q_dims_ptr + num_q_dims);
86
-
87
- uint64_t *r_dims_ptr = (uint64_t *)in[4 ];
88
- std::vector<uint64_t > r_dims (r_dims_ptr, r_dims_ptr + num_r_dims);
89
-
90
- uint64_t m = q_dims[q_dims.size () - 2 ];
91
- uint64_t k = q_dims[q_dims.size () - 1 ];
92
- uint64_t n = r_dims[r_dims.size () - 1 ];
93
- bool complete = r_dims[r_dims.size () - 2 ] == m;
94
-
95
- auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
96
-
97
- uint64_t batch_items = 1 ;
98
- for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
99
- batch_items *= leading_dimensions[i];
100
- }
3
+ void qr_cpu_custom_call_f32 (void *out[], const void *in[]);
4
+ void qr_cpu_custom_call_f64 (void *out[], const void *in[]);
5
+ void qr_cpu_custom_call_f16 (void *out[], const void *in[]);
6
+ void qr_cpu_custom_call_bf16 (void *out[], const void *in[]);
7
+ void eigh_cpu_custom_call_f32 (void *out[], const void *in[]);
8
+ void eigh_cpu_custom_call_f64 (void *out[], const void *in[]);
101
9
102
- DataType *q = (DataType *)out[0 ];
103
- DataType *r = (DataType *)out[1 ];
104
-
105
- uint64_t r_stride = r_dims[r_dims.size () - 1 ] * r_dims[r_dims.size () - 2 ] * sizeof (DataType);
106
- uint64_t q_stride = q_dims[q_dims.size () - 1 ] * q_dims[q_dims.size () - 2 ] * sizeof (DataType);
107
- uint64_t inner_stride = m * n * sizeof (DataType);
108
-
109
- for (uint64_t i = 0 ; i < batch_items; i++) {
110
- single_matrix_qr_cpu_custom_call<DataType>(
111
- (DataType *)out[0 ] + i * q_stride,
112
- (DataType *)out[1 ] + i * r_stride,
113
- operand + i * inner_stride * sizeof (DataType),
114
- m, k, n, complete);
115
- }
116
- }
117
-
118
- template <typename DataType>
119
- void eigh_cpu_custom_call (void *out[], const void *in[]) {
120
- DataType *operand = (DataType *)in[0 ];
121
-
122
- uint64_t *dim_sizes = (uint64_t *)in[1 ];
123
- uint64_t num_operand_dims = dim_sizes[0 ];
124
- uint64_t num_eigenvalues_dims = dim_sizes[1 ];
125
- uint64_t num_eigenvectors_dims = dim_sizes[2 ];
126
-
127
- uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
128
- std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
129
-
130
- uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3 ];
131
- std::vector<uint64_t > eigenvalues_dims (eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
132
-
133
- uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4 ];
134
- std::vector<uint64_t > eigenvectors_dims (eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
135
-
136
- uint64_t m = eigenvectors_dims[eigenvectors_dims.size () - 2 ];
137
- uint64_t n = eigenvectors_dims[eigenvectors_dims.size () - 1 ];
138
-
139
- auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
140
-
141
- uint64_t batch_items = 1 ;
142
- for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
143
- batch_items *= leading_dimensions[i];
144
- }
145
-
146
- DataType *eigenvalues = (DataType *)out[0 ];
147
- DataType *eigenvectors = (DataType *)out[1 ];
148
-
149
- uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ] * sizeof (DataType);
150
- uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size () - 1 ] * eigenvectors_dims[eigenvectors_dims.size () - 2 ] * sizeof (DataType);
151
- uint64_t inner_stride = m * n * sizeof (DataType);
152
-
153
- for (uint64_t i = 0 ; i < batch_items; i++) {
154
- single_matrix_eigh_cpu_custom_call<DataType>(
155
- eigenvalues + i * eigenvalues_stride,
156
- eigenvectors + i * eigenvectors_stride,
157
- operand + i * inner_stride / sizeof (DataType),
158
- m, n);
159
- }
160
- }
161
-
162
- void qr_cpu_custom_call_bf16 (void *out[], const void *in[]) {
163
- qr_cpu_custom_call<exla::bfloat16>(out, in);
164
- }
165
-
166
- void qr_cpu_custom_call_f16 (void *out[], const void *in[]) {
167
- qr_cpu_custom_call<exla::float16>(out, in);
168
- }
169
-
170
- void qr_cpu_custom_call_f32 (void *out[], const void *in[]) {
171
- qr_cpu_custom_call<float >(out, in);
172
- }
173
-
174
- void qr_cpu_custom_call_f64 (void *out[], const void *in[]) {
175
- qr_cpu_custom_call<double >(out, in);
176
- }
177
-
178
- void eigh_cpu_custom_call_f32 (void *out[], const void *in[]) {
179
- eigh_cpu_custom_call<float >(out, in);
180
- }
181
-
182
- void eigh_cpu_custom_call_f64 (void *out[], const void *in[]) {
183
- eigh_cpu_custom_call<double >(out, in);
184
- }
185
-
186
- XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f32" , qr_cpu_custom_call_f32);
187
10
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f64" , qr_cpu_custom_call_f64);
11
+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f32" , qr_cpu_custom_call_f32);
188
12
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f16" , qr_cpu_custom_call_f16);
189
13
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_bf16" , qr_cpu_custom_call_bf16);
190
-
191
-
192
- XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f32" , eigh_cpu_custom_call_f32);
193
14
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f64" , eigh_cpu_custom_call_f64);
15
+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f32" , eigh_cpu_custom_call_f32);
0 commit comments