Commit 2502181

Pass custom call sizes as unsigned integers (#1526)
1 parent 3ef1c7a commit 2502181

2 files changed: 54 additions & 54 deletions

exla/c_src/exla/custom_calls.cc

Lines changed: 46 additions & 46 deletions
@@ -7,7 +7,7 @@
 #include "xla/service/custom_call_target_registry.h"
 
 template <typename DataType>
-void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, int64_t m, int64_t n) {
+void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
   typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
 
   // Map the input matrix
@@ -33,7 +33,7 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
 }
 
 template <typename DataType>
-void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
+void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, uint64_t m, uint64_t k, uint64_t n, bool complete) {
   typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
 
   Eigen::Map<RowMajorMatrix> input(in, m, n);
@@ -48,8 +48,8 @@ void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType
 
     num_bytes_q = m * m * sizeof(DataType);
 
-    for (int64_t i = 0; i < m; ++i) {
-      for (int64_t j = 0; j < n; ++j) {
+    for (uint64_t i = 0; i < m; ++i) {
+      for (uint64_t j = 0; j < n; ++j) {
         r_out[i * n + j] = (j >= i) ? R(i, j) : static_cast<DataType>(0.0);
       }
     }
@@ -59,8 +59,8 @@ void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType
 
     num_bytes_q = m * k * sizeof(DataType);
 
-    for (int64_t i = 0; i < k; ++i) {
-      for (int64_t j = 0; j < n; ++j) {
+    for (uint64_t i = 0; i < k; ++i) {
+      for (uint64_t j = 0; j < n; ++j) {
         r_out[i * n + j] = (j >= i) ? R(i, j) : static_cast<DataType>(0.0);
       }
     }
@@ -73,40 +73,40 @@ template <typename DataType>
 void qr_cpu_custom_call(void *out[], const void *in[]) {
   DataType *operand = (DataType *)in[0];
 
-  int64_t *dim_sizes = (int64_t *)in[1];
-  int64_t num_operand_dims = dim_sizes[0];
-  int64_t num_q_dims = dim_sizes[1];
-  int64_t num_r_dims = dim_sizes[2];
+  uint64_t *dim_sizes = (uint64_t *)in[1];
+  uint64_t num_operand_dims = dim_sizes[0];
+  uint64_t num_q_dims = dim_sizes[1];
+  uint64_t num_r_dims = dim_sizes[2];
 
-  int64_t *operand_dims_ptr = (int64_t *)in[2];
-  std::vector<int64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
+  uint64_t *operand_dims_ptr = (uint64_t *)in[2];
+  std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
 
-  int64_t *q_dims_ptr = (int64_t *)in[3];
-  std::vector<int64_t> q_dims(q_dims_ptr, q_dims_ptr + num_q_dims);
+  uint64_t *q_dims_ptr = (uint64_t *)in[3];
+  std::vector<uint64_t> q_dims(q_dims_ptr, q_dims_ptr + num_q_dims);
 
-  int64_t *r_dims_ptr = (int64_t *)in[4];
-  std::vector<int64_t> r_dims(r_dims_ptr, r_dims_ptr + num_r_dims);
+  uint64_t *r_dims_ptr = (uint64_t *)in[4];
+  std::vector<uint64_t> r_dims(r_dims_ptr, r_dims_ptr + num_r_dims);
 
-  int64_t m = q_dims[q_dims.size() - 2];
-  int64_t k = q_dims[q_dims.size() - 1];
-  int64_t n = r_dims[r_dims.size() - 1];
+  uint64_t m = q_dims[q_dims.size() - 2];
+  uint64_t k = q_dims[q_dims.size() - 1];
+  uint64_t n = r_dims[r_dims.size() - 1];
   bool complete = r_dims[r_dims.size() - 2] == m;
 
-  auto leading_dimensions = std::vector<int64_t>(operand_dims.begin(), operand_dims.end() - 2);
+  auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
 
-  int64_t batch_items = 1;
-  for (int64_t i = 0; i < leading_dimensions.size(); i++) {
+  uint64_t batch_items = 1;
+  for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
     batch_items *= leading_dimensions[i];
   }
 
   DataType *q = (DataType *)out[0];
   DataType *r = (DataType *)out[1];
 
-  int64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType);
-  int64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType);
-  int64_t inner_stride = m * n * sizeof(DataType);
+  uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType);
+  uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType);
+  uint64_t inner_stride = m * n * sizeof(DataType);
 
-  for (int64_t i = 0; i < batch_items; i++) {
+  for (uint64_t i = 0; i < batch_items; i++) {
     single_matrix_qr_cpu_custom_call<DataType>(
         (DataType *)out[0] + i * q_stride,
         (DataType *)out[1] + i * r_stride,
@@ -119,38 +119,38 @@ template <typename DataType>
 void eigh_cpu_custom_call(void *out[], const void *in[]) {
   DataType *operand = (DataType *)in[0];
 
-  int64_t *dim_sizes = (int64_t *)in[1];
-  int64_t num_operand_dims = dim_sizes[0];
-  int64_t num_eigenvalues_dims = dim_sizes[1];
-  int64_t num_eigenvectors_dims = dim_sizes[2];
+  uint64_t *dim_sizes = (uint64_t *)in[1];
+  uint64_t num_operand_dims = dim_sizes[0];
+  uint64_t num_eigenvalues_dims = dim_sizes[1];
+  uint64_t num_eigenvectors_dims = dim_sizes[2];
 
-  int64_t *operand_dims_ptr = (int64_t *)in[2];
-  std::vector<int64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
+  uint64_t *operand_dims_ptr = (uint64_t *)in[2];
+  std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
 
-  int64_t *eigenvalues_dims_ptr = (int64_t *)in[3];
-  std::vector<int64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
+  uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
+  std::vector<uint64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
 
-  int64_t *eigenvectors_dims_ptr = (int64_t *)in[4];
-  std::vector<int64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
+  uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
+  std::vector<uint64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
 
-  int64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
-  int64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
+  uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
+  uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
 
-  auto leading_dimensions = std::vector<int64_t>(operand_dims.begin(), operand_dims.end() - 2);
+  auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
 
-  int64_t batch_items = 1;
-  for (int64_t i = 0; i < leading_dimensions.size(); i++) {
+  uint64_t batch_items = 1;
+  for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
     batch_items *= leading_dimensions[i];
   }
 
   DataType *eigenvalues = (DataType *)out[0];
   DataType *eigenvectors = (DataType *)out[1];
 
-  int64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
-  int64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
-  int64_t inner_stride = m * n * sizeof(DataType);
+  uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
+  uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
+  uint64_t inner_stride = m * n * sizeof(DataType);
 
-  for (int64_t i = 0; i < batch_items; i++) {
+  for (uint64_t i = 0; i < batch_items; i++) {
     single_matrix_eigh_cpu_custom_call<DataType>(
         eigenvalues + i * eigenvalues_stride,
         eigenvectors + i * eigenvectors_stride,
@@ -190,4 +190,4 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_c
 
 
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
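
For reference, a minimal standalone sketch (not EXLA code; the shapes and the main() driver are made up for illustration) of the metadata layout these custom calls now assume: every dimension count and size in in[1]..in[4] is packed and read back as a 64-bit unsigned integer, mirroring the casts in qr_cpu_custom_call above.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical QR call: a batch of two 4x3 matrices, reduced factorization.
  std::vector<uint64_t> operand_dims = {2, 4, 3};
  std::vector<uint64_t> q_dims = {2, 4, 3};  // reduced Q is 4x3
  std::vector<uint64_t> r_dims = {2, 3, 3};  // R is 3x3
  std::vector<uint64_t> dim_sizes = {operand_dims.size(), q_dims.size(), r_dims.size()};

  // The custom call only sees opaque pointers; in[0] would be the operand data.
  const void *in[5] = {nullptr, dim_sizes.data(), operand_dims.data(),
                       q_dims.data(), r_dims.data()};

  // Decode the way qr_cpu_custom_call now does: unsigned 64-bit reads throughout.
  const uint64_t *sizes = (const uint64_t *)in[1];
  const uint64_t *q = (const uint64_t *)in[3];
  const uint64_t *r = (const uint64_t *)in[4];
  uint64_t m = q[sizes[1] - 2];
  uint64_t k = q[sizes[1] - 1];
  uint64_t n = r[sizes[2] - 1];
  bool complete = r[sizes[2] - 2] == m;

  std::printf("m=%llu k=%llu n=%llu complete=%d\n", (unsigned long long)m,
              (unsigned long long)k, (unsigned long long)n, complete ? 1 : 0);
  return 0;
}

With these made-up shapes the sketch prints m=4 k=3 n=3 complete=0.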

exla/lib/exla/mlir/value.ex

Lines changed: 8 additions & 8 deletions
@@ -720,14 +720,14 @@ defmodule EXLA.MLIR.Value do
     eigenvecs_dims = Tuple.to_list(eigenvecs_shape)
     eigenvals_dims = Tuple.to_list(eigenvals_shape)
 
-    dim_sizes = constant(func, dim_sizes, Typespec.tensor({:s, 64}, {length(dim_sizes)}))
-    operand_dims = constant(func, operand_dims, Typespec.tensor({:s, 64}, {length(operand_dims)}))
+    dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
+    operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
 
     eigenvecs_dims =
-      constant(func, eigenvecs_dims, Typespec.tensor({:s, 64}, {length(eigenvecs_dims)}))
+      constant(func, eigenvecs_dims, Typespec.tensor({:u, 64}, {length(eigenvecs_dims)}))
 
     eigenvals_dims =
-      constant(func, eigenvals_dims, Typespec.tensor({:s, 64}, {length(eigenvals_dims)}))
+      constant(func, eigenvals_dims, Typespec.tensor({:u, 64}, {length(eigenvals_dims)}))
 
     operands = [value, dim_sizes, operand_dims, eigenvecs_dims, eigenvals_dims]
 
@@ -772,10 +772,10 @@ defmodule EXLA.MLIR.Value do
     q_dims = Tuple.to_list(q_shape)
     r_dims = Tuple.to_list(r_shape)
 
-    dim_sizes = constant(func, dim_sizes, Typespec.tensor({:s, 64}, {length(dim_sizes)}))
-    operand_dims = constant(func, operand_dims, Typespec.tensor({:s, 64}, {length(operand_dims)}))
-    q_dims = constant(func, q_dims, Typespec.tensor({:s, 64}, {length(q_dims)}))
-    r_dims = constant(func, r_dims, Typespec.tensor({:s, 64}, {length(r_dims)}))
+    dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
+    operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
+    q_dims = constant(func, q_dims, Typespec.tensor({:u, 64}, {length(q_dims)}))
+    r_dims = constant(func, r_dims, Typespec.tensor({:u, 64}, {length(r_dims)}))
     operands = [value, dim_sizes, operand_dims, q_dims, r_dims]
 
     q_result_type = type_tensor(q_type, q_shape)
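
The Elixir side stays in lockstep with that C++ reinterpretation: the dimension constants are now emitted as {:u, 64} tensors so the bytes the custom call casts to uint64_t * were written as unsigned 64-bit values in the first place. As a small illustrative sketch (contrived value, not part of this commit), reading the same 8 bytes back with mismatched signedness silently yields a negative size:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // A 64-bit size with the high bit set (contrived, only to expose the hazard).
  uint64_t dim = 0x8000000000000001ull;

  unsigned char bytes[8];
  std::memcpy(bytes, &dim, sizeof(bytes));  // what a u64 constant tensor would carry

  uint64_t as_unsigned;
  int64_t as_signed;
  std::memcpy(&as_unsigned, bytes, sizeof(as_unsigned));
  std::memcpy(&as_signed, bytes, sizeof(as_signed));

  // The unsigned read round-trips; the signed read comes back negative, which
  // would corrupt stride arithmetic and loop bounds downstream.
  std::printf("as uint64_t: %llu\n", (unsigned long long)as_unsigned);
  std::printf("as int64_t:  %lld\n", (long long)as_signed);
  return 0;
}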
