Skip to content

Commit d31c33e

Browse files
authored
Improve compile times for EXLA (#1539)
Authored by: Paulo Valente <[email protected]>
1 parent 6d12d3e commit d31c33e

18 files changed

+308
-347
lines changed

exla/Makefile

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
6161
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
6262
fi
6363

64-
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc
65-
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
64+
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc
65+
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
66+
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
6667
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
6768

6869

@@ -83,7 +84,7 @@ $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cud
8384

8485
$(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS)
8586
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)
86-
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)/mlir
87+
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)/custom_calls
8788
$(CXX) $(CFLAGS) -c $< -o $@
8889

8990
$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(OBJECTS)

exla/c_src/exla/custom_calls.cc

Lines changed: 8 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -1,193 +1,15 @@
1-
#include "custom_calls.h"
2-
3-
#include "Eigen/Dense"
4-
#include "Eigen/Eigenvalues"
5-
#include "Eigen/QR"
6-
#include "exla_nif_util.h"
71
#include "xla/service/custom_call_target_registry.h"
82

9-
template <typename DataType>
10-
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
11-
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
12-
13-
// Map the input matrix
14-
Eigen::Map<RowMajorMatrix> input(in, m, n);
15-
16-
// Compute the Eigenvalue decomposition
17-
Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver(input);
18-
19-
if (eigensolver.info() != Eigen::Success) {
20-
std::cerr << "Eigenvalue decomposition failed!" << std::endl;
21-
return;
22-
}
23-
24-
// Get the eigenvalues and eigenvectors
25-
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
26-
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();
27-
28-
// Copy the eigenvalues to the output
29-
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
30-
31-
// Copy the eigenvectors to the output
32-
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
33-
}
34-
35-
template <typename DataType>
36-
void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, uint64_t m, uint64_t k, uint64_t n, bool complete) {
37-
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
38-
39-
Eigen::Map<RowMajorMatrix> input(in, m, n);
40-
Eigen::HouseholderQR<RowMajorMatrix> qr = input.householderQr();
41-
42-
RowMajorMatrix Q, R;
43-
size_t num_bytes_q, num_bytes_r;
44-
45-
if (complete) {
46-
Q = qr.householderQ() * RowMajorMatrix::Identity(m, m);
47-
R = qr.matrixQR();
48-
49-
num_bytes_q = m * m * sizeof(DataType);
50-
51-
for (uint64_t i = 0; i < m; ++i) {
52-
for (uint64_t j = 0; j < n; ++j) {
53-
r_out[i * n + j] = (j >= i) ? R(i, j) : static_cast<DataType>(0.0);
54-
}
55-
}
56-
} else {
57-
Q = qr.householderQ() * RowMajorMatrix::Identity(m, k);
58-
R = qr.matrixQR().topRows(k);
59-
60-
num_bytes_q = m * k * sizeof(DataType);
61-
62-
for (uint64_t i = 0; i < k; ++i) {
63-
for (uint64_t j = 0; j < n; ++j) {
64-
r_out[i * n + j] = (j >= i) ? R(i, j) : static_cast<DataType>(0.0);
65-
}
66-
}
67-
}
68-
69-
memcpy(q_out, Q.data(), num_bytes_q);
70-
}
71-
72-
template <typename DataType>
73-
void qr_cpu_custom_call(void *out[], const void *in[]) {
74-
DataType *operand = (DataType *)in[0];
75-
76-
uint64_t *dim_sizes = (uint64_t *)in[1];
77-
uint64_t num_operand_dims = dim_sizes[0];
78-
uint64_t num_q_dims = dim_sizes[1];
79-
uint64_t num_r_dims = dim_sizes[2];
80-
81-
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
82-
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
83-
84-
uint64_t *q_dims_ptr = (uint64_t *)in[3];
85-
std::vector<uint64_t> q_dims(q_dims_ptr, q_dims_ptr + num_q_dims);
86-
87-
uint64_t *r_dims_ptr = (uint64_t *)in[4];
88-
std::vector<uint64_t> r_dims(r_dims_ptr, r_dims_ptr + num_r_dims);
89-
90-
uint64_t m = q_dims[q_dims.size() - 2];
91-
uint64_t k = q_dims[q_dims.size() - 1];
92-
uint64_t n = r_dims[r_dims.size() - 1];
93-
bool complete = r_dims[r_dims.size() - 2] == m;
94-
95-
auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
96-
97-
uint64_t batch_items = 1;
98-
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
99-
batch_items *= leading_dimensions[i];
100-
}
3+
void qr_cpu_custom_call_f32(void *out[], const void *in[]);
4+
void qr_cpu_custom_call_f64(void *out[], const void *in[]);
5+
void qr_cpu_custom_call_f16(void *out[], const void *in[]);
6+
void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
7+
void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
8+
void eigh_cpu_custom_call_f64(void *out[], const void *in[]);
1019

102-
DataType *q = (DataType *)out[0];
103-
DataType *r = (DataType *)out[1];
104-
105-
uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType);
106-
uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType);
107-
uint64_t inner_stride = m * n * sizeof(DataType);
108-
109-
for (uint64_t i = 0; i < batch_items; i++) {
110-
single_matrix_qr_cpu_custom_call<DataType>(
111-
(DataType *)out[0] + i * q_stride,
112-
(DataType *)out[1] + i * r_stride,
113-
operand + i * inner_stride * sizeof(DataType),
114-
m, k, n, complete);
115-
}
116-
}
117-
118-
template <typename DataType>
119-
void eigh_cpu_custom_call(void *out[], const void *in[]) {
120-
DataType *operand = (DataType *)in[0];
121-
122-
uint64_t *dim_sizes = (uint64_t *)in[1];
123-
uint64_t num_operand_dims = dim_sizes[0];
124-
uint64_t num_eigenvalues_dims = dim_sizes[1];
125-
uint64_t num_eigenvectors_dims = dim_sizes[2];
126-
127-
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
128-
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
129-
130-
uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
131-
std::vector<uint64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
132-
133-
uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
134-
std::vector<uint64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
135-
136-
uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
137-
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
138-
139-
auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
140-
141-
uint64_t batch_items = 1;
142-
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
143-
batch_items *= leading_dimensions[i];
144-
}
145-
146-
DataType *eigenvalues = (DataType *)out[0];
147-
DataType *eigenvectors = (DataType *)out[1];
148-
149-
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
150-
uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
151-
uint64_t inner_stride = m * n * sizeof(DataType);
152-
153-
for (uint64_t i = 0; i < batch_items; i++) {
154-
single_matrix_eigh_cpu_custom_call<DataType>(
155-
eigenvalues + i * eigenvalues_stride,
156-
eigenvectors + i * eigenvectors_stride,
157-
operand + i * inner_stride / sizeof(DataType),
158-
m, n);
159-
}
160-
}
161-
162-
void qr_cpu_custom_call_bf16(void *out[], const void *in[]) {
163-
qr_cpu_custom_call<exla::bfloat16>(out, in);
164-
}
165-
166-
void qr_cpu_custom_call_f16(void *out[], const void *in[]) {
167-
qr_cpu_custom_call<exla::float16>(out, in);
168-
}
169-
170-
void qr_cpu_custom_call_f32(void *out[], const void *in[]) {
171-
qr_cpu_custom_call<float>(out, in);
172-
}
173-
174-
void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
175-
qr_cpu_custom_call<double>(out, in);
176-
}
177-
178-
void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
179-
eigh_cpu_custom_call<float>(out, in);
180-
}
181-
182-
void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
183-
eigh_cpu_custom_call<double>(out, in);
184-
}
185-
186-
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
18710
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
11+
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
18812
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
18913
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
190-
191-
192-
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
19314
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
15+
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);

exla/c_src/exla/custom_calls.h

Lines changed: 0 additions & 12 deletions
This file was deleted.

exla/c_src/exla/custom_calls/eigh.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#pragma once
2+
3+
#include "Eigen/Eigenvalues"
4+
5+
#include <iostream>
6+
7+
template <typename DataType>
8+
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
9+
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
10+
11+
// Map the input matrix
12+
Eigen::Map<RowMajorMatrix> input(in, m, n);
13+
14+
// Compute the Eigenvalue decomposition
15+
Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver(input);
16+
17+
if (eigensolver.info() != Eigen::Success) {
18+
std::cerr << "Eigenvalue decomposition failed!" << std::endl;
19+
return;
20+
}
21+
22+
// Get the eigenvalues and eigenvectors
23+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
24+
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();
25+
26+
// Copy the eigenvalues to the output
27+
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
28+
29+
// Copy the eigenvectors to the output
30+
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
31+
}
32+
33+
template <typename DataType>
34+
void eigh_cpu_custom_call(void *out[], const void *in[]) {
35+
DataType *operand = (DataType *)in[0];
36+
37+
uint64_t *dim_sizes = (uint64_t *)in[1];
38+
uint64_t num_operand_dims = dim_sizes[0];
39+
uint64_t num_eigenvalues_dims = dim_sizes[1];
40+
uint64_t num_eigenvectors_dims = dim_sizes[2];
41+
42+
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
43+
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
44+
45+
uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
46+
std::vector<uint64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
47+
48+
uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
49+
std::vector<uint64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
50+
51+
uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
52+
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
53+
54+
auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
55+
56+
uint64_t batch_items = 1;
57+
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
58+
batch_items *= leading_dimensions[i];
59+
}
60+
61+
DataType *eigenvalues = (DataType *)out[0];
62+
DataType *eigenvectors = (DataType *)out[1];
63+
64+
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
65+
uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
66+
uint64_t inner_stride = m * n * sizeof(DataType);
67+
68+
for (uint64_t i = 0; i < batch_items; i++) {
69+
single_matrix_eigh_cpu_custom_call<DataType>(
70+
eigenvalues + i * eigenvalues_stride,
71+
eigenvectors + i * eigenvectors_stride,
72+
operand + i * inner_stride / sizeof(DataType),
73+
m, n);
74+
}
75+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "eigh.h"
2+
3+
void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
4+
eigh_cpu_custom_call<float>(out, in);
5+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "eigh.h"
2+
3+
void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
4+
eigh_cpu_custom_call<double>(out, in);
5+
}

0 commit comments

Comments
 (0)