Skip to content

Commit 7335990

Browse files
authored
Merge branch 'elixir-nx:main' into docs_refact
2 parents 22aab8f + 1980af9 commit 7335990

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2703
-2367
lines changed

exla/Makefile

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LI
2222
EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
2323

2424
# Build flags
25-
# c++17 is needed, otherwise xla headers
26-
# break on some conflicting llvm/std definitions
27-
# Note: this is on :xla 0.5.0 -- things can change with later versions
28-
CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \
25+
# Note that both XLA and Fine require C++17
26+
CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(FINE_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \
2927
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
3028
-std=c++17 -w -DLLVM_VERSION_STRING=
3129

@@ -38,27 +36,29 @@ else
3836
CFLAGS += -O3
3937
endif
4038

41-
LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared
39+
NVCC := $(CXX)
40+
NVCCFLAGS = $(CFLAGS)
41+
LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden
4242

4343
ifeq ($(CROSSCOMPILE),)
4444
# Interrogate the system for local compilation
4545
UNAME_S = $(shell uname -s)
4646

47+
ifdef ($(EXLA_CPU_ONLY),)
48+
$(info EXLA_CPU_ONLY is not set, checking for nvcc availability)
4749
NVCC_RESULT := $(shell which nvcc 2> /dev/null)
4850
NVCC_TEST := $(notdir $(NVCC_RESULT))
4951

50-
ifeq ($(NVCC_TEST),nvcc)
51-
NVCC := nvcc
52-
NVCCFLAGS += -DCUDA_ENABLED
52+
ifeq ($(NVCC_TEST),nvcc)
53+
NVCC := nvcc
54+
NVCCFLAGS += -DCUDA_ENABLED
55+
endif
5356
else
54-
NVCC := $(CXX)
55-
NVCCFLAGS = $(CFLAGS)
57+
$(info EXLA_CPU_ONLY is set, skipping nvcc step)
5658
endif
5759
else
5860
# Determine settings for cross-compiled builds like for Nerves
5961
UNAME_S = Linux
60-
NVCC := $(CXX)
61-
NVCCFLAGS = $(CFLAGS)
6262
endif
6363

6464
ifeq ($(UNAME_S), Darwin)
@@ -82,7 +82,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
8282
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
8383
fi
8484

85-
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc
85+
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/ipc.cc
8686
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
8787
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
8888
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o

exla/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ EXLA relies on the [XLA](https://github.com/elixir-nx/xla) package to provide th
4343
* Incompatible protocol buffer versions
4444
* Error message: "this file was generated by an older version of protoc which is incompatible with your Protocol Buffer headers".
4545
* If you have `protoc` installed on your machine, it may conflict with the `protoc` precompiled inside XLA. Uninstall, unlink, or remove `protoc` from your path to continue.
46+
* Missing CUDA symbols
47+
* In some cases, you might be compiling a CPU-only version of `:xla` in an environment that has CUDA available, which can lead to missing CUDA symbols. In these cases, you can set the `EXLA_CPU_ONLY` environment variable to any value to disable custom CUDA functionality in EXLA.
4648

4749
### Usage with Nerves
4850

exla/c_src/exla/custom_calls/eigh.h

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22

33
#include "Eigen/Eigenvalues"
44

5+
#include <algorithm>
56
#include <iostream>
7+
#include <numeric>
8+
#include <vector>
69

710
template <typename DataType>
8-
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
9-
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
11+
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
12+
DataType *eigenvectors_out,
13+
DataType *in, uint64_t m, uint64_t n) {
14+
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
15+
Eigen::RowMajor>
16+
RowMajorMatrix;
1017

1118
// Map the input matrix
1219
Eigen::Map<RowMajorMatrix> input(in, m, n);
@@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig
2027
}
2128

2229
// Get the eigenvalues and eigenvectors
23-
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
30+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues =
31+
eigensolver.eigenvalues();
2432
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();
2533

26-
// Copy the eigenvalues to the output
27-
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
34+
// Create a vector of indices and sort it by eigenvalue magnitude in decreasing
35+
// order
36+
std::vector<int> indices(m);
37+
std::iota(indices.begin(), indices.end(), 0);
38+
std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) {
39+
return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j));
40+
});
41+
42+
// Sort eigenvalues and rearrange eigenvectors
43+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> sorted_eigenvalues(m);
44+
RowMajorMatrix sorted_eigenvectors(m, n);
45+
for (int i = 0; i < m; ++i) {
46+
sorted_eigenvalues(i) = eigenvalues(indices[i]);
47+
sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]);
48+
}
49+
50+
// Copy the sorted eigenvalues to the output
51+
std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), m * sizeof(DataType));
2852

29-
// Copy the eigenvectors to the output
30-
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
53+
// Copy the sorted eigenvectors to the output
54+
std::memcpy(eigenvectors_out, sorted_eigenvectors.data(),
55+
m * n * sizeof(DataType));
3156
}
3257

3358
template <typename DataType>
@@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
4065
uint64_t num_eigenvectors_dims = dim_sizes[2];
4166

4267
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
43-
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
68+
std::vector<uint64_t> operand_dims(operand_dims_ptr,
69+
operand_dims_ptr + num_operand_dims);
4470

4571
uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
46-
std::vector<uint64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
72+
std::vector<uint64_t> eigenvalues_dims(
73+
eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
4774

4875
uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
49-
std::vector<uint64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
76+
std::vector<uint64_t> eigenvectors_dims(
77+
eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
5078

5179
uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
5280
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
5381

54-
auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
82+
auto leading_dimensions =
83+
std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
5584

5685
uint64_t batch_items = 1;
5786
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
@@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) {
6190
DataType *eigenvalues = (DataType *)out[0];
6291
DataType *eigenvectors = (DataType *)out[1];
6392

64-
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
65-
uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
66-
uint64_t inner_stride = m * n * sizeof(DataType);
93+
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
94+
uint64_t eigenvectors_stride =
95+
eigenvectors_dims[eigenvectors_dims.size() - 1] *
96+
eigenvectors_dims[eigenvectors_dims.size() - 2];
97+
uint64_t inner_stride = m * n;
6798

6899
for (uint64_t i = 0; i < batch_items; i++) {
69100
single_matrix_eigh_cpu_custom_call<DataType>(
70101
eigenvalues + i * eigenvalues_stride,
71-
eigenvectors + i * eigenvectors_stride,
72-
operand + i * inner_stride / sizeof(DataType),
73-
m, n);
102+
eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
103+
n);
74104
}
75105
}

0 commit comments

Comments
 (0)