Skip to content

Commit c6f8cec

Browse files
authored
feat: add to/from pointer (#1473)
1 parent ac771f5 commit c6f8cec

File tree

18 files changed

+409
-10
lines changed

18 files changed

+409
-10
lines changed

exla/Makefile

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compa
2929
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
3030
-std=c++17 -w -DLLVM_VERSION_STRING=
3131

32+
NVCCFLAGS = -shared -Xcompiler -fPIC
33+
3234
ifdef DEBUG
3335
CFLAGS += -g
36+
NVCCFLAGS += -g
3437
else
3538
CFLAGS += -O3
3639
endif
@@ -60,7 +63,23 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
6063

6164
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/mlir/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
6265
HEADERS = $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/mlir/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
63-
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES))
66+
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
67+
68+
69+
NVCC_RESULT := $(shell which nvcc 2> NULL)
70+
NVCC_TEST := $(notdir $(NVCC_RESULT))
71+
72+
ifeq ($(NVCC_TEST),nvcc)
73+
NVCC := nvcc
74+
NVCCFLAGS += -DCUDA_ENABLED
75+
else
76+
NVCC := g++
77+
NVCCFLAGS = $(CFLAGS)
78+
endif
79+
80+
$(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cuda.h
81+
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)
82+
$(NVCC) $(NVCCFLAGS) -c $< -o $@
6483

6584
$(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS)
6685
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)

exla/c_src/exla/exla.cc

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <string>
55

66
#include "exla_client.h"
7+
#include "exla_cuda.h"
78
#include "exla_log_sink.h"
89
#include "exla_nif_util.h"
910
#include "mlir/ops.h"
@@ -134,6 +135,104 @@ ERL_NIF_TERM create_sub_builder(ErlNifEnv* env, int argc, const ERL_NIF_TERM arg
134135

135136
// ExlaBuffer Functions
136137

138+
ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
139+
if (argc != 3) {
140+
return exla::nif::error(env, "Bad argument count.");
141+
}
142+
143+
exla::ExlaClient** client;
144+
exla::ExlaBuffer** buffer;
145+
std::string pointer_kind;
146+
147+
if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
148+
return exla::nif::error(env, "Unable to get client.");
149+
}
150+
if (!exla::nif::get<exla::ExlaBuffer*>(env, argv[1], buffer)) {
151+
return exla::nif::error(env, "Unable to get buffer.");
152+
}
153+
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
154+
return exla::nif::error(env, "Unable to get device pointer kind.");
155+
}
156+
157+
EXLA_ASSIGN_OR_RETURN_NIF(std::uintptr_t ptr,
158+
(*buffer)->GetDevicePointer((*client)->client()), env);
159+
160+
std::vector<unsigned char> pointer_vec;
161+
if (pointer_kind == "local") {
162+
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
163+
for (size_t i = 0; i < sizeof(void*); i++) {
164+
pointer_vec.push_back(bytePtr[i]);
165+
}
166+
} else if (pointer_kind == "cuda_ipc") {
167+
auto result = get_cuda_ipc_handle(ptr);
168+
if (result.second) {
169+
return exla::nif::error(env, "Unable to get cuda IPC handle");
170+
}
171+
pointer_vec = result.first;
172+
}
173+
174+
EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);
175+
176+
ERL_NIF_TERM handle_list[pointer_vec.size()];
177+
for (int i = 0; i < pointer_vec.size(); i++) {
178+
handle_list[i] = enif_make_uint(env, pointer_vec[i]);
179+
}
180+
181+
ERL_NIF_TERM handle_list_term = enif_make_list_from_array(env, handle_list, pointer_vec.size());
182+
ERL_NIF_TERM device_size_term = enif_make_uint64(env, device_size);
183+
184+
return exla::nif::ok(env, enif_make_tuple2(env, handle_list_term, device_size_term));
185+
}
186+
187+
ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
188+
if (argc != 5) {
189+
return exla::nif::error(env, "Bad argument count.");
190+
}
191+
192+
exla::ExlaClient** client;
193+
std::vector<int64_t> pointer_vec;
194+
xla::Shape* shape;
195+
int device_id;
196+
std::string pointer_kind;
197+
198+
if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
199+
return exla::nif::error(env, "Unable to get client.");
200+
}
201+
if (!exla::nif::get_list(env, argv[1], pointer_vec)) {
202+
return exla::nif::error(env, "Unable to get device pointer.");
203+
}
204+
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
205+
return exla::nif::error(env, "Unable to get device pointer kind.");
206+
}
207+
if (!exla::nif::get<xla::Shape>(env, argv[3], shape)) {
208+
return exla::nif::error(env, "Unable to get shape.");
209+
}
210+
if (!exla::nif::get(env, argv[4], &device_id)) {
211+
return exla::nif::error(env, "Unable to get device ordinal.");
212+
}
213+
214+
void* ptr;
215+
if (pointer_kind == "local") {
216+
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
217+
for (size_t i = 0; i < sizeof(void*); i++) {
218+
bytePtr[i] = pointer_vec[i];
219+
}
220+
} else if (pointer_kind == "cuda_ipc") {
221+
auto result = get_pointer_for_ipc_handle(pointer_vec);
222+
if (result.second) {
223+
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
224+
}
225+
ptr = result.first;
226+
}
227+
228+
EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(device_id), env);
229+
230+
std::function<void()> on_delete_callback = []() {};
231+
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, *shape, device, on_delete_callback), env);
232+
exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer));
233+
return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
234+
}
235+
137236
ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
138237
if (argc != 4) {
139238
return exla::nif::error(env, "Bad argument count.");
@@ -710,6 +809,8 @@ static ErlNifFunc exla_funcs[] = {
710809
{"get_supported_platforms", 0, get_supported_platforms},
711810
{"mlir_compile", 7, mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND},
712811
// ExlaBuffer
812+
{"get_buffer_device_pointer", 3, get_buffer_device_pointer},
813+
{"create_buffer_from_device_pointer", 5, create_buffer_from_device_pointer},
713814
{"binary_to_device_mem", 4, binary_to_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND},
714815
{"read_device_mem", 2, read_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND},
715816
{"deallocate_device_mem", 1, deallocate_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND},

exla/c_src/exla/exla_client.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class ExlaBuffer {
3131
xla::StatusOr<ERL_NIF_TERM> ToBinary(ErlNifEnv* env, exla::int64 size);
3232
xla::Status Deallocate();
3333

34+
xla::StatusOr<std::uintptr_t> GetDevicePointer(xla::PjRtClient* client) {
35+
return client->UnsafeBufferPointer(buffer_.get());
36+
}
37+
38+
xla::StatusOr<size_t> GetOnDeviceSizeInBytes() {
39+
return buffer_.get()->GetOnDeviceSizeInBytes();
40+
}
41+
3442
~ExlaBuffer() {
3543
// Theoretically this may block if a computation is running
3644
// but we always block the host until the computation is done.

exla/c_src/exla/exla_cuda.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "exla_cuda.h"
2+
3+
#ifdef CUDA_ENABLED
4+
#include <cuda_runtime.h>
5+
6+
#include <cstring>
7+
#include <iostream>
8+
9+
std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t ptr) {
10+
cudaIpcMemHandle_t ipc_handle;
11+
cudaError_t status = cudaIpcGetMemHandle(&ipc_handle, reinterpret_cast<void*>(ptr));
12+
13+
// Assuming sizeof(cudaIpcMemHandle_t) is constant
14+
const size_t size = sizeof(cudaIpcMemHandle_t);
15+
16+
// Copy the memory handle to a byte array
17+
std::vector<unsigned char> result(size);
18+
memcpy(result.data(), &ipc_handle, size);
19+
20+
return std::make_pair(result, status != cudaSuccess);
21+
}
22+
23+
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list) {
24+
unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)];
25+
for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
26+
ipc_handle_data[i] = (uint8_t)handle_list[i];
27+
}
28+
29+
cudaIpcMemHandle_t ipc_handle;
30+
memcpy(&ipc_handle, ipc_handle_data, sizeof(cudaIpcMemHandle_t));
31+
32+
int* ptr;
33+
cudaError_t cuda_status = cudaSetDevice(0); // Assuming device 0, change as needed
34+
if (cuda_status != cudaSuccess) {
35+
printf("Error setting CUDA device: %s\n", cudaGetErrorString(cuda_status));
36+
return std::make_pair(nullptr, 1); // Return with error status
37+
}
38+
39+
cuda_status = cudaIpcOpenMemHandle((void**)&ptr, ipc_handle, cudaIpcMemLazyEnablePeerAccess);
40+
if (cuda_status != cudaSuccess) {
41+
printf("Error opening CUDA IPC memory handle: %s\n", cudaGetErrorString(cuda_status));
42+
return std::make_pair(nullptr, 1); // Return with error status
43+
}
44+
45+
return std::make_pair(ptr, cuda_status != cudaSuccess);
46+
}
47+
#else
48+
std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t ptr) {
49+
return std::make_pair(std::vector<unsigned char>(0), 1);
50+
}
51+
52+
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list) {
53+
return std::make_pair(nullptr, 1);
54+
}
55+
#endif

exla/c_src/exla/exla_cuda.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <vector>
5+
6+
std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t);
7+
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>);

exla/c_src/exla/exla_nif_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ ERL_NIF_TERM make(ErlNifEnv* env, const char* string);
117117
// their signatures are the same for retrieving/returning
118118
// regular strings.
119119

120-
int get_atom(ErlNifEnv* env, ERL_NIF_TERM term, std::string* var);
120+
int get_atom(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var);
121121

122122
ERL_NIF_TERM atom(ErlNifEnv* env, const char* status);
123123

exla/lib/exla/backend.ex

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,71 @@ defmodule EXLA.Backend do
8282
EXLA.DeviceBuffer.deallocate(buffer)
8383
end
8484

85+
@impl true
86+
def to_pointer(%T{data: %B{buffer: buffer}}, opts \\ []) do
87+
opts = Keyword.validate!(opts, mode: :local)
88+
89+
mode =
90+
case opts[:mode] do
91+
mode when mode in [:local, :cuda_ipc] ->
92+
mode
93+
94+
mode ->
95+
raise ArgumentError, "expected one of :local, :cuda_ipc, got: #{inspect(mode)}"
96+
end
97+
98+
case buffer do
99+
%EXLA.DeviceBuffer{} ->
100+
:ok
101+
102+
_ ->
103+
raise ArgumentError, "tensor must be allocated via a #{DeviceBuffer}"
104+
end
105+
106+
client = EXLA.Client.fetch!(buffer.client_name)
107+
108+
case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do
109+
{:ok, {pointer, _size}} ->
110+
{:ok, pointer}
111+
112+
error ->
113+
error
114+
end
115+
end
116+
117+
@impl true
118+
def from_pointer(pointer, type, dims, backend_opts, opts) do
119+
backend_opts = Keyword.validate!(backend_opts, [:client_name, :device_id])
120+
opts = Keyword.validate!(opts, [:names, mode: :local])
121+
122+
template = Nx.template(dims, type, names: opts[:names])
123+
124+
client_name = backend_opts[:client_name] || EXLA.Client.default_name()
125+
client = EXLA.Client.fetch!(client_name)
126+
127+
device_id = backend_opts[:device_id] || client.default_device_id
128+
129+
shape = EXLA.Shape.make_shape(type, dims)
130+
131+
result =
132+
EXLA.NIF.create_buffer_from_device_pointer(
133+
client.ref,
134+
pointer,
135+
opts[:mode],
136+
shape.ref,
137+
device_id
138+
)
139+
140+
case result do
141+
{:ok, ref} ->
142+
buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, shape)
143+
{:ok, %{template | data: %EXLA.Backend{buffer: buffer}}}
144+
145+
error ->
146+
error
147+
end
148+
end
149+
85150
@impl true
86151
def to_batched(out, tensor, opts) do
87152
leftover = opts[:leftover]

exla/lib/exla/nif.ex

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,17 @@ defmodule EXLA.NIF do
229229
),
230230
do: :erlang.nif_error(:undef)
231231

232+
def get_buffer_device_pointer(_client, _buffer, _pointer_kind), do: :erlang.nif_error(:undef)
233+
234+
def create_buffer_from_device_pointer(
235+
_client,
236+
_opaque_pointer,
237+
_pointer_kind,
238+
_shape,
239+
_device_id
240+
),
241+
do: :erlang.nif_error(:undef)
242+
232243
def binary_to_device_mem(_client, _binary, _shape, _device_ordinal),
233244
do: :erlang.nif_error(:undef)
234245

exla/mix.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ defmodule EXLA.MixProject do
5959

6060
defp deps do
6161
[
62-
{:nx, "~> 0.7.1"},
63-
# {:nx, path: "../nx"},
62+
# {:nx, "~> 0.7.1"},
63+
{:nx, path: "../nx"},
6464
{:telemetry, "~> 0.4.0 or ~> 1.0"},
6565
{:xla, "~> 0.6.0", runtime: false},
6666
{:elixir_make, "~> 0.6", runtime: false},
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
defmodule EXLA.DeviceMemorySharingTest do
2+
use EXLA.Case, async: false
3+
4+
for client_name <- [:host, :cuda] do
5+
if client_name == :cuda do
6+
@tag :cuda_required
7+
end
8+
9+
test "buffer sharing on #{inspect(client_name)} works as expected" do
10+
t1 = Nx.tensor([1, 2, 3], backend: {EXLA.Backend, client: unquote(client_name)})
11+
12+
assert inspect(t1) =~ "1, 2, 3"
13+
14+
assert {:ok, pointer} = Nx.to_pointer(t1, mode: :local)
15+
16+
assert {:ok, t2} =
17+
Nx.from_pointer(
18+
{EXLA.Backend, client_name: unquote(client_name)},
19+
pointer,
20+
t1.type,
21+
t1.shape
22+
)
23+
24+
assert t1.data.buffer.ref != t2.data.buffer.ref
25+
assert Nx.to_binary(t1) == Nx.to_binary(t2)
26+
end
27+
end
28+
end

0 commit comments

Comments
 (0)