Skip to content

Commit c772f3c

Browse files
committed
fix(exla): use passed device id and protect against some segfaults
1 parent 01960d9 commit c772f3c

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

exla/c_src/exla/exla.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,17 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
213213

214214
void* ptr;
215215
if (pointer_kind == "local") {
216+
if (pointer_vec.size() != sizeof(void*)) {
217+
// This helps prevent segfaults if someone passes an IPC handle instead of
218+
// a local pointer.
219+
return exla::nif::error(env, "Invalid pointer size for selected mode.");
220+
}
216221
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
217222
for (size_t i = 0; i < sizeof(void*); i++) {
218223
bytePtr[i] = pointer_vec[i];
219224
}
220225
} else if (pointer_kind == "cuda_ipc") {
221-
auto result = get_pointer_for_ipc_handle(pointer_vec);
226+
auto result = get_pointer_for_ipc_handle(pointer_vec, device_id);
222227
if (result.second) {
223228
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
224229
}

exla/c_src/exla/exla_cuda.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t pt
2020
return std::make_pair(result, status != cudaSuccess);
2121
}
2222

23-
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list) {
23+
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list, int device_id) {
24+
if (handle_list.size() != sizeof(cudaIpcMemHandle_t)) {
25+
printf("Error: Invalid CUDA IPC memory handle size\n");
26+
return std::make_pair(nullptr, 1); // Return with error status
27+
}
28+
2429
unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)];
2530
for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
2631
ipc_handle_data[i] = (uint8_t)handle_list[i];
@@ -30,7 +35,7 @@ std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_lis
3035
memcpy(&ipc_handle, ipc_handle_data, sizeof(cudaIpcMemHandle_t));
3136

3237
int* ptr;
33-
cudaError_t cuda_status = cudaSetDevice(0); // Assuming device 0, change as needed
38+
cudaError_t cuda_status = cudaSetDevice(device_id); // Assuming device 0, change as needed
3439
if (cuda_status != cudaSuccess) {
3540
printf("Error setting CUDA device: %s\n", cudaGetErrorString(cuda_status));
3641
return std::make_pair(nullptr, 1); // Return with error status

exla/c_src/exla/exla_cuda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
#include <vector>
55

66
std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t);
7-
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>);
7+
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>, int);

exla/test/exla/device_memory_sharing_test.exs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,16 @@ defmodule EXLA.DeviceMemorySharingTest do
2525
assert Nx.to_binary(t1) == Nx.to_binary(t2)
2626
end
2727
end
28+
29+
@tag :cuda_required
30+
test "ipc handles don't crash the runtime when :local mode is selected" do
31+
assert {:error, ~c"Invalid pointer size for selected mode."} ==
32+
Nx.from_pointer(
33+
{EXLA.Backend, client_name: :cuda},
34+
Enum.to_list(0..63),
35+
{:f, 32},
36+
{1},
37+
mode: :local
38+
)
39+
end
2840
end

0 commit comments

Comments
 (0)