Skip to content

Commit 03a0c1c

Browse files
Use generic API for creating MLIR operations (#1477)
Co-authored-by: Paulo Valente <[email protected]>
1 parent 9d8a1ff commit 03a0c1c

38 files changed

+2174
-5164
lines changed

exla/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
6161
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
6262
fi
6363

64-
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/mlir/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
65-
HEADERS = $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/mlir/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
64+
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
65+
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
6666
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
6767

6868

exla/c_src/exla/mlir/custom_calls.cc renamed to exla/c_src/exla/custom_calls.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#include "custom_calls.h"
2+
#include "exla_nif_util.h"
23

3-
#include <Eigen/Dense>
4-
#include <Eigen/QR>
4+
#include "xla/service/custom_call_target_registry.h"
55

6-
#include "builder.h"
6+
#include "Eigen/Dense"
7+
#include "Eigen/QR"
78

89
template <typename DataType>
910
void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
@@ -102,4 +103,9 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]) {
102103

103104
void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
104105
qr_cpu_custom_call<double>(out, in);
105-
}
106+
}
107+
108+
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
109+
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
110+
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
111+
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
#pragma once
2-
3-
template <typename DataType>
4-
void qr_cpu_custom_call(void *out[], const void *in[]);
1+
#ifndef EXLA_MLIR_CUSTOM_CALLS_H_
2+
#define EXLA_MLIR_CUSTOM_CALLS_H_
53

64
void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
75
void qr_cpu_custom_call_f16(void *out[], const void *in[]);
86
void qr_cpu_custom_call_f32(void *out[], const void *in[]);
9-
void qr_cpu_custom_call_f64(void *out[], const void *in[]);
7+
void qr_cpu_custom_call_f64(void *out[], const void *in[]);
8+
9+
#endif

0 commit comments

Comments
 (0)