7
7
#include " xla/service/custom_call_target_registry.h"
8
8
9
9
template <typename DataType>
10
- void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, int64_t m, int64_t n) {
10
+ void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
11
11
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
12
12
13
13
// Map the input matrix
@@ -33,7 +33,7 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
33
33
}
34
34
35
35
template <typename DataType>
36
- void single_matrix_qr_cpu_custom_call (DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
36
+ void single_matrix_qr_cpu_custom_call (DataType *q_out, DataType *r_out, DataType *in, uint64_t m, uint64_t k, uint64_t n, bool complete) {
37
37
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
38
38
39
39
Eigen::Map<RowMajorMatrix> input (in, m, n);
@@ -48,8 +48,8 @@ void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType
48
48
49
49
num_bytes_q = m * m * sizeof (DataType);
50
50
51
- for (int64_t i = 0 ; i < m; ++i) {
52
- for (int64_t j = 0 ; j < n; ++j) {
51
+ for (uint64_t i = 0 ; i < m; ++i) {
52
+ for (uint64_t j = 0 ; j < n; ++j) {
53
53
r_out[i * n + j] = (j >= i) ? R (i, j) : static_cast <DataType>(0.0 );
54
54
}
55
55
}
@@ -59,8 +59,8 @@ void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType
59
59
60
60
num_bytes_q = m * k * sizeof (DataType);
61
61
62
- for (int64_t i = 0 ; i < k; ++i) {
63
- for (int64_t j = 0 ; j < n; ++j) {
62
+ for (uint64_t i = 0 ; i < k; ++i) {
63
+ for (uint64_t j = 0 ; j < n; ++j) {
64
64
r_out[i * n + j] = (j >= i) ? R (i, j) : static_cast <DataType>(0.0 );
65
65
}
66
66
}
@@ -73,40 +73,40 @@ template <typename DataType>
73
73
void qr_cpu_custom_call (void *out[], const void *in[]) {
74
74
DataType *operand = (DataType *)in[0 ];
75
75
76
- int64_t *dim_sizes = (int64_t *)in[1 ];
77
- int64_t num_operand_dims = dim_sizes[0 ];
78
- int64_t num_q_dims = dim_sizes[1 ];
79
- int64_t num_r_dims = dim_sizes[2 ];
76
+ uint64_t *dim_sizes = (uint64_t *)in[1 ];
77
+ uint64_t num_operand_dims = dim_sizes[0 ];
78
+ uint64_t num_q_dims = dim_sizes[1 ];
79
+ uint64_t num_r_dims = dim_sizes[2 ];
80
80
81
- int64_t *operand_dims_ptr = (int64_t *)in[2 ];
82
- std::vector<int64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
81
+ uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
82
+ std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
83
83
84
- int64_t *q_dims_ptr = (int64_t *)in[3 ];
85
- std::vector<int64_t > q_dims (q_dims_ptr, q_dims_ptr + num_q_dims);
84
+ uint64_t *q_dims_ptr = (uint64_t *)in[3 ];
85
+ std::vector<uint64_t > q_dims (q_dims_ptr, q_dims_ptr + num_q_dims);
86
86
87
- int64_t *r_dims_ptr = (int64_t *)in[4 ];
88
- std::vector<int64_t > r_dims (r_dims_ptr, r_dims_ptr + num_r_dims);
87
+ uint64_t *r_dims_ptr = (uint64_t *)in[4 ];
88
+ std::vector<uint64_t > r_dims (r_dims_ptr, r_dims_ptr + num_r_dims);
89
89
90
- int64_t m = q_dims[q_dims.size () - 2 ];
91
- int64_t k = q_dims[q_dims.size () - 1 ];
92
- int64_t n = r_dims[r_dims.size () - 1 ];
90
+ uint64_t m = q_dims[q_dims.size () - 2 ];
91
+ uint64_t k = q_dims[q_dims.size () - 1 ];
92
+ uint64_t n = r_dims[r_dims.size () - 1 ];
93
93
bool complete = r_dims[r_dims.size () - 2 ] == m;
94
94
95
- auto leading_dimensions = std::vector<int64_t >(operand_dims.begin (), operand_dims.end () - 2 );
95
+ auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
96
96
97
- int64_t batch_items = 1 ;
98
- for (int64_t i = 0 ; i < leading_dimensions.size (); i++) {
97
+ uint64_t batch_items = 1 ;
98
+ for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
99
99
batch_items *= leading_dimensions[i];
100
100
}
101
101
102
102
DataType *q = (DataType *)out[0 ];
103
103
DataType *r = (DataType *)out[1 ];
104
104
105
- int64_t r_stride = r_dims[r_dims.size () - 1 ] * r_dims[r_dims.size () - 2 ] * sizeof (DataType);
106
- int64_t q_stride = q_dims[q_dims.size () - 1 ] * q_dims[q_dims.size () - 2 ] * sizeof (DataType);
107
- int64_t inner_stride = m * n * sizeof (DataType);
105
+ uint64_t r_stride = r_dims[r_dims.size () - 1 ] * r_dims[r_dims.size () - 2 ] * sizeof (DataType);
106
+ uint64_t q_stride = q_dims[q_dims.size () - 1 ] * q_dims[q_dims.size () - 2 ] * sizeof (DataType);
107
+ uint64_t inner_stride = m * n * sizeof (DataType);
108
108
109
- for (int64_t i = 0 ; i < batch_items; i++) {
109
+ for (uint64_t i = 0 ; i < batch_items; i++) {
110
110
single_matrix_qr_cpu_custom_call<DataType>(
111
111
(DataType *)out[0 ] + i * q_stride,
112
112
(DataType *)out[1 ] + i * r_stride,
@@ -119,38 +119,38 @@ template <typename DataType>
119
119
void eigh_cpu_custom_call (void *out[], const void *in[]) {
120
120
DataType *operand = (DataType *)in[0 ];
121
121
122
- int64_t *dim_sizes = (int64_t *)in[1 ];
123
- int64_t num_operand_dims = dim_sizes[0 ];
124
- int64_t num_eigenvalues_dims = dim_sizes[1 ];
125
- int64_t num_eigenvectors_dims = dim_sizes[2 ];
122
+ uint64_t *dim_sizes = (uint64_t *)in[1 ];
123
+ uint64_t num_operand_dims = dim_sizes[0 ];
124
+ uint64_t num_eigenvalues_dims = dim_sizes[1 ];
125
+ uint64_t num_eigenvectors_dims = dim_sizes[2 ];
126
126
127
- int64_t *operand_dims_ptr = (int64_t *)in[2 ];
128
- std::vector<int64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
127
+ uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
128
+ std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
129
129
130
- int64_t *eigenvalues_dims_ptr = (int64_t *)in[3 ];
131
- std::vector<int64_t > eigenvalues_dims (eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
130
+ uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3 ];
131
+ std::vector<uint64_t > eigenvalues_dims (eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
132
132
133
- int64_t *eigenvectors_dims_ptr = (int64_t *)in[4 ];
134
- std::vector<int64_t > eigenvectors_dims (eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
133
+ uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4 ];
134
+ std::vector<uint64_t > eigenvectors_dims (eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
135
135
136
- int64_t m = eigenvectors_dims[eigenvectors_dims.size () - 2 ];
137
- int64_t n = eigenvectors_dims[eigenvectors_dims.size () - 1 ];
136
+ uint64_t m = eigenvectors_dims[eigenvectors_dims.size () - 2 ];
137
+ uint64_t n = eigenvectors_dims[eigenvectors_dims.size () - 1 ];
138
138
139
- auto leading_dimensions = std::vector<int64_t >(operand_dims.begin (), operand_dims.end () - 2 );
139
+ auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
140
140
141
- int64_t batch_items = 1 ;
142
- for (int64_t i = 0 ; i < leading_dimensions.size (); i++) {
141
+ uint64_t batch_items = 1 ;
142
+ for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
143
143
batch_items *= leading_dimensions[i];
144
144
}
145
145
146
146
DataType *eigenvalues = (DataType *)out[0 ];
147
147
DataType *eigenvectors = (DataType *)out[1 ];
148
148
149
- int64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ] * sizeof (DataType);
150
- int64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size () - 1 ] * eigenvectors_dims[eigenvectors_dims.size () - 2 ] * sizeof (DataType);
151
- int64_t inner_stride = m * n * sizeof (DataType);
149
+ uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ] * sizeof (DataType);
150
+ uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size () - 1 ] * eigenvectors_dims[eigenvectors_dims.size () - 2 ] * sizeof (DataType);
151
+ uint64_t inner_stride = m * n * sizeof (DataType);
152
152
153
- for (int64_t i = 0 ; i < batch_items; i++) {
153
+ for (uint64_t i = 0 ; i < batch_items; i++) {
154
154
single_matrix_eigh_cpu_custom_call<DataType>(
155
155
eigenvalues + i * eigenvalues_stride,
156
156
eigenvectors + i * eigenvectors_stride,
@@ -190,4 +190,4 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_c
190
190
191
191
192
192
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f32" , eigh_cpu_custom_call_f32);
193
- XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f64" , eigh_cpu_custom_call_f64);
193
+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f64" , eigh_cpu_custom_call_f64);
0 commit comments