Skip to content

Commit 93e4383

Browse files
authored
feat(exla): Host IPC and revamped pointer representation (#1531)
1 parent 116b124 commit 93e4383

File tree

10 files changed

+260
-76
lines changed

10 files changed

+260
-76
lines changed

exla/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
6161
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
6262
fi
6363

64-
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
65-
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
64+
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc
65+
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
6666
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
6767

6868

exla/c_src/exla/exla.cc

Lines changed: 88 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
#include <sstream>
12
#include <string>
23

34
#include "exla_client.h"
45
#include "exla_cuda.h"
56
#include "exla_log_sink.h"
67
#include "exla_mlir.h"
78
#include "exla_nif_util.h"
9+
#include "ipc.h"
810
#include "mhlo/IR/hlo_ops.h"
911
#include "mlir/Dialect/Func/IR/FuncOps.h"
1012
#include "stablehlo/dialect/ChloOps.h"
@@ -449,34 +451,60 @@ ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_T
449451
return exla::nif::error(env, "Unable to get device pointer kind.");
450452
}
451453

454+
EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);
455+
452456
EXLA_ASSIGN_OR_RETURN_NIF(std::uintptr_t ptr,
453457
(*buffer)->GetDevicePointer((*client)->client()), env);
454458

455-
std::vector<unsigned char> pointer_vec;
459+
ERL_NIF_TERM out_term;
456460
if (pointer_kind == "local") {
457-
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
458-
for (size_t i = 0; i < sizeof(void*); i++) {
459-
pointer_vec.push_back(bytePtr[i]);
461+
ERL_NIF_TERM ptr_term = enif_make_ulong(env, ptr);
462+
ERL_NIF_TERM size_term = enif_make_ulong(env, device_size);
463+
out_term = enif_make_tuple2(env, ptr_term, size_term);
464+
} else if (pointer_kind == "host_ipc") {
465+
std::ostringstream handle_name_stream;
466+
handle_name_stream << "exla:ipc:" << device_size << ":" << ptr;
467+
std::string handle_name = handle_name_stream.str();
468+
int fd = get_ipc_handle((char*)handle_name.c_str(), device_size);
469+
470+
if (fd == -1) {
471+
return exla::nif::error(env, "Unable to get IPC handle");
472+
}
473+
474+
void* ipc_ptr = open_ipc_handle(fd, device_size);
475+
if (ipc_ptr == nullptr) {
476+
return exla::nif::error(env, "Unable to open IPC handle");
460477
}
478+
479+
memcpy(ipc_ptr, (void*)ptr, device_size);
480+
481+
ErlNifBinary handle_name_bin;
482+
enif_alloc_binary(handle_name.size(), &handle_name_bin);
483+
for (int i = 0; i < handle_name.size(); i++) {
484+
handle_name_bin.data[i] = handle_name[i];
485+
}
486+
ERL_NIF_TERM handle_name_term = enif_make_binary(env, &handle_name_bin);
487+
ERL_NIF_TERM size_term = enif_make_uint64(env, device_size);
488+
ERL_NIF_TERM fd_term = enif_make_int(env, fd);
489+
out_term = enif_make_tuple3(env, handle_name_term, fd_term, size_term);
461490
} else if (pointer_kind == "cuda_ipc") {
462491
auto result = get_cuda_ipc_handle(ptr);
463492
if (result.second) {
464493
return exla::nif::error(env, "Unable to get cuda IPC handle");
465494
}
466-
pointer_vec = result.first;
467-
}
495+
auto pointer_vec = result.first;
468496

469-
EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);
470-
471-
ERL_NIF_TERM handle_list[pointer_vec.size()];
472-
for (int i = 0; i < pointer_vec.size(); i++) {
473-
handle_list[i] = enif_make_uint(env, pointer_vec[i]);
497+
ErlNifBinary handle_bin;
498+
enif_alloc_binary(pointer_vec.size(), &handle_bin);
499+
for (int i = 0; i < pointer_vec.size(); i++) {
500+
handle_bin.data[i] = pointer_vec[i];
501+
}
502+
ERL_NIF_TERM handle_term = enif_make_binary(env, &handle_bin);
503+
ERL_NIF_TERM size_term = enif_make_uint64(env, device_size);
504+
out_term = enif_make_tuple2(env, handle_term, size_term);
474505
}
475506

476-
ERL_NIF_TERM handle_list_term = enif_make_list_from_array(env, handle_list, pointer_vec.size());
477-
ERL_NIF_TERM device_size_term = enif_make_uint64(env, device_size);
478-
479-
return exla::nif::ok(env, enif_make_tuple2(env, handle_list_term, device_size_term));
507+
return exla::nif::ok(env, out_term);
480508
}
481509

482510
ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
@@ -485,40 +513,68 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
485513
}
486514

487515
exla::ExlaClient** client;
488-
std::vector<int64_t> pointer_vec;
516+
ErlNifBinary cuda_ipc_handle_bin;
517+
int cuda_ipc_handle_size = 0;
489518
xla::Shape shape;
490519
int device_id;
491520
std::string pointer_kind;
521+
void* ptr;
522+
int fd = -1;
523+
std::string memname;
492524

493525
if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
494526
return exla::nif::error(env, "Unable to get client.");
495527
}
496-
if (!exla::nif::get_list(env, argv[1], pointer_vec)) {
497-
return exla::nif::error(env, "Unable to get device pointer.");
498-
}
499-
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
528+
if (!exla::nif::get_atom(env, argv[1], pointer_kind)) {
500529
return exla::nif::error(env, "Unable to get device pointer kind.");
501530
}
531+
532+
if (pointer_kind == "cuda_ipc") {
533+
if (!enif_inspect_binary(env, argv[2], &cuda_ipc_handle_bin)) {
534+
return exla::nif::error(env, "Unable to get CUDA IPC handle.");
535+
}
536+
} else if (pointer_kind == "host_ipc") {
537+
const ERL_NIF_TERM* tuple;
538+
int arity;
539+
if (
540+
!enif_get_tuple(env, argv[2], &arity, &tuple) ||
541+
(arity != 2) ||
542+
!exla::nif::get(env, tuple[0], &fd) ||
543+
(fd == -1) ||
544+
!exla::nif::get(env, tuple[1], memname)) {
545+
return exla::nif::error(env, "Unable to get IPC handle.");
546+
}
547+
} else if (pointer_kind == "local") {
548+
int64_t ptr_int;
549+
if (!exla::nif::get(env, argv[2], &ptr_int)) {
550+
return exla::nif::error(env, "Unable to get pointer.");
551+
}
552+
553+
ptr = (void*)ptr_int;
554+
}
555+
502556
if (!exla::nif::get_typespec_as_xla_shape(env, argv[3], &shape)) {
503557
return exla::nif::error(env, "Unable to get shape.");
504558
}
505559
if (!exla::nif::get(env, argv[4], &device_id)) {
506560
return exla::nif::error(env, "Unable to get device ordinal.");
507561
}
508562

509-
void* ptr;
510-
if (pointer_kind == "local") {
511-
if (pointer_vec.size() != sizeof(void*)) {
512-
// This helps prevent segfaults if someone passes an IPC handle instead of
513-
// a local pointer.
514-
return exla::nif::error(env, "Invalid pointer size for selected mode.");
515-
}
516-
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
517-
for (size_t i = 0; i < sizeof(void*); i++) {
518-
bytePtr[i] = pointer_vec[i];
563+
std::function<void()> on_delete_callback = []() {};
564+
565+
if (pointer_kind == "host_ipc") {
566+
size_t device_size = (size_t)xla::ShapeUtil::ByteSizeOf(shape);
567+
568+
ptr = open_ipc_handle(fd, device_size);
569+
if (ptr == nullptr) {
570+
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
519571
}
572+
573+
on_delete_callback = [fd, memname, ptr, device_size]() {
574+
close_ipc_handle(fd, ptr, (char*)memname.c_str(), device_size);
575+
};
520576
} else if (pointer_kind == "cuda_ipc") {
521-
auto result = get_pointer_for_ipc_handle(pointer_vec, device_id);
577+
auto result = get_pointer_for_ipc_handle(cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id);
522578
if (result.second) {
523579
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
524580
}
@@ -527,8 +583,8 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
527583

528584
EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);
529585

530-
std::function<void()> on_delete_callback = []() {};
531586
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env);
587+
532588
exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer));
533589
return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
534590
}

exla/c_src/exla/exla_cuda.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
2020
return std::make_pair(result, status != cudaSuccess);
2121
}
2222

23-
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
24-
if (handle_list.size() != sizeof(cudaIpcMemHandle_t)) {
25-
printf("Error: Invalid CUDA IPC memory handle size\n");
23+
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) {
24+
if (handle_size != sizeof(cudaIpcMemHandle_t)) {
2625
return std::make_pair(nullptr, 1); // Return with error status
2726
}
2827

2928
unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)];
3029
for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
31-
ipc_handle_data[i] = (uint8_t)handle_list[i];
30+
ipc_handle_data[i] = handle_bin[i];
3231
}
3332

3433
cudaIpcMemHandle_t ipc_handle;
@@ -54,7 +53,7 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
5453
return std::make_pair(std::vector<unsigned char>(0), 1);
5554
}
5655

57-
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
56+
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) {
5857
return std::make_pair(nullptr, 1);
5958
}
6059
#endif

exla/c_src/exla/exla_cuda.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#pragma once
22

3+
#include <cstddef>
34
#include <cstdint>
45
#include <vector>
56

67
std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t);
7-
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>, int);
8+
std::pair<void*, int> get_pointer_for_ipc_handle(uint8_t*, size_t, int);

exla/c_src/exla/ipc.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "ipc.h"
2+
3+
#include <fcntl.h>
4+
#include <sys/mman.h>
5+
#include <sys/stat.h>
6+
#include <unistd.h>
7+
8+
#include <iostream>
9+
10+
// Function to create or open a shared memory object and set its size
11+
int get_ipc_handle(const char* memname, size_t memsize) {
12+
int fd = shm_open(memname, O_CREAT | O_RDWR, 0666);
13+
if (fd == -1) {
14+
return -1;
15+
}
16+
17+
if (ftruncate(fd, memsize) == -1) {
18+
close(fd);
19+
return -1;
20+
}
21+
22+
return fd;
23+
}
24+
25+
// Function to map the shared memory in this process
26+
void* open_ipc_handle(int fd, size_t memsize) {
27+
void* ptr = mmap(NULL, memsize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
28+
if (ptr == MAP_FAILED) {
29+
perror("mmap");
30+
return nullptr;
31+
}
32+
return ptr;
33+
}
34+
35+
int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize) {
36+
if (munmap(ptr, memsize) == -1) {
37+
return -1;
38+
}
39+
40+
if (close(fd) == -1) {
41+
return -1;
42+
}
43+
44+
shm_unlink(memname);
45+
46+
return 0;
47+
}

exla/c_src/exla/ipc.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
5+
int get_ipc_handle(const char* memname, size_t memsize);
6+
void* open_ipc_handle(int fd, size_t memsize);
7+
int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize);

0 commit comments

Comments
 (0)