diff --git a/src/xccl/IPCExchange.hpp b/src/xccl/IPCExchange.hpp new file mode 100644 index 0000000000..f057dbf593 --- /dev/null +++ b/src/xccl/IPCExchange.hpp @@ -0,0 +1,446 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "xccl/ze_exception.hpp" + +#include + +#include +#include +#include +#include // for std::chrono::milliseconds +#include // for std::this_thread::sleep_for + +#define ELE_COUNT 128 + +struct exchange_contents { + // first 4-byte is file descriptor for drmbuf or gem object + union { + ze_ipc_mem_handle_t ipc_handle; + int fd = -1; + }; + size_t offset = 0; + int pid = -1; +}; + +#define sysCheck(x) \ + if (x == -1) { \ + throw std::system_error(std::make_error_code(std::errc(errno))); \ + } + +// We can't inherit it from cmsghdr because flexible array member +struct exchange_fd { + char obscure[CMSG_LEN(sizeof(int)) - sizeof(int)]; + int fd; + + exchange_fd(int cmsg_level, int cmsg_type, int fd) : fd(fd) { + auto* cmsg = reinterpret_cast(obscure); + cmsg->cmsg_len = sizeof(exchange_fd); + cmsg->cmsg_level = cmsg_level; + cmsg->cmsg_type = cmsg_type; + } + + exchange_fd() : fd(-1) { + memset(obscure, 0, sizeof(obscure)); + }; +}; + +void un_send_fd(int sock, int fd, int rank, size_t offset) { + iovec iov[1]; + msghdr msg; + auto rank_offset = std::make_pair(rank, offset); + + iov[0].iov_base = &rank_offset; + iov[0].iov_len = sizeof(rank_offset); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + + exchange_fd cmsg(SOL_SOCKET, SCM_RIGHTS, fd); + + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(exchange_fd); + sysCheck(sendmsg(sock, &msg, 0)); +} + +std::tuple un_recv_fd(int sock) { + iovec iov[1]; + msghdr msg; + std::pair rank_offset; + + iov[0].iov_base = &rank_offset; + iov[0].iov_len = sizeof(rank_offset); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + + exchange_fd cmsg; + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(exchange_fd); + int n_recv = recvmsg(sock, &msg, 0); + sysCheck(n_recv); + // assert(n_recv == sizeof(int)); + + return std::make_tuple(cmsg.fd, rank_offset.first, rank_offset.second); +} + +int prepare_socket(const char* sockname) { + sockaddr_un un; + memset(&un, 0, sizeof(un)); + un.sun_family = AF_UNIX; + strcpy(un.sun_path, sockname); + + auto sock = socket(AF_UNIX, SOCK_STREAM, 0); + sysCheck(sock); + + int on = 1; + sysCheck(ioctl(sock, FIONBIO, &on)); + + auto size = offsetof(sockaddr_un, sun_path) + strlen(un.sun_path); + sysCheck(bind(sock, (sockaddr*)&un, size)); + + return sock; +} + +int server_listen(const char* sockname) { + unlink(sockname); + auto sock = prepare_socket(sockname); + sysCheck(listen(sock, 10)); + + return sock; +} + +int serv_accept(int listen_sock) { + sockaddr_un un; + + socklen_t len = sizeof(un); + auto accept_sock = accept(listen_sock, (sockaddr*)&un, &len); + sysCheck(accept_sock); + + return accept_sock; +} + +bool wait_for_socket_file(const char* path, int max_seconds = 10) { + struct stat buffer; + for (int i = 0; i < max_seconds * 10; ++i) { + if (stat(path, &buffer) == 0) { + return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return false; +} + +int client_connect(const char* server, const char* client) { + if (!wait_for_socket_file(server, 10)) { + std::cerr << "Error: timeout waiting for server socket file: " << server + << std::endl; + exit(EXIT_FAILURE); + } + auto sock = prepare_socket(client); + 
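// Usage sketch for the socket helpers above (illustrative only, not from the
// patch): server_listen()/serv_accept()/un_recv_fd() form the receiving side,
// client_connect()/un_send_fd() the sending side. Socket paths, fd and offset
// values are placeholders, and the sockets are non-blocking, so real callers
// guard accept/send with poll() as un_allgather() does below.
//
//   // Receiving process (e.g. rank 0):
//   int listen_sock = server_listen("/tmp/demo-server-rank_0");
//   int conn = serv_accept(listen_sock);                     // after POLLIN fires
//   auto [peer_fd, peer_rank, peer_offset] = un_recv_fd(conn);
//
//   // Sending process (e.g. rank 1):
//   int sock = client_connect("/tmp/demo-server-rank_0", "/tmp/demo-client-rank_1");
//   un_send_fd(sock, dmabuf_fd, /*rank=*/1, /*offset=*/0);   // after POLLOUT fires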
sockaddr_un sun; + memset(&sun, 0, sizeof(sun)); + sun.sun_family = AF_UNIX; + strcpy(sun.sun_path, server); + auto len = offsetof(sockaddr_un, sun_path) + strlen(server); + const int max_retries = 50; + int retry = 0; + int ret = -1; + while (retry < max_retries) { + ret = connect(sock, (sockaddr*)&sun, len); + if (ret == 0) + break; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + retry++; + } + if (ret != 0) { + perror("connect failed"); + exit(EXIT_FAILURE); + } + + // sysCheck(connect(sock, (sockaddr*)&sun, len)); + return sock; +} + +void un_allgather( + exchange_contents* send_buf, + exchange_contents recv_buf[], + int rank, + int world) { + const char* servername_prefix = "/tmp/open-peer-ipc-mem-server-rank_"; + const char* clientname_prefix = "/tmp/open-peer-ipc-mem-client-rank_"; + char server_name[64]; + /* get username to make server_name unique */ + auto uid = getuid(); + auto pwd = getpwuid(uid); + snprintf( + server_name, + sizeof(server_name), + "%s%d_%s", + servername_prefix, + rank, + pwd->pw_name); + unlink(server_name); + auto s_listen = server_listen(server_name); + + // MPI_Barrier(MPI_COMM_WORLD); + + pollfd fdarray[world]; + int recv_socks[world - 1]; + + for (auto& pollfd : fdarray) + pollfd.fd = -1; + std::fill(recv_socks, recv_socks + world - 1, -1); + + auto fd_guard = [&]() { + for (int i = 0, j = 0; i < world; ++i) { + if (i != rank && recv_socks[j] != -1) + sysCheck(close(recv_socks[j++])); + if (fdarray[i].fd != -1) + sysCheck(close(fdarray[i].fd)); + } + }; + + struct guard__ { + using F = decltype(fd_guard); + F f; + guard__(const F& f) : f(f) {} + ~guard__() { + f(); + } + } free_fd(fd_guard); + + // connect to all ranks + for (int i = 0; i < world; ++i) { + if (rank == i) { + fdarray[i].fd = s_listen; + fdarray[i].events = POLLIN; + fdarray[i].revents = 0; + } else { + char peer_name[64]; + char client_name[64]; + + snprintf( + client_name, + sizeof(client_name), + "%s%d-%d_%s", + clientname_prefix, + rank, + i, + pwd->pw_name); + unlink(client_name); + + snprintf( + peer_name, + sizeof(peer_name), + "%s%d_%s", + servername_prefix, + i, + pwd->pw_name); + fdarray[i].fd = client_connect(peer_name, client_name); + fdarray[i].events = POLLOUT; + fdarray[i].revents = 0; + } + } + + // std::future> future_fds[world -1]; + int slot = 0; + uint32_t send_progress = 1 << rank; + + while (slot < world - 1 || send_progress != (1 << world) - 1) { + sysCheck(ppoll(fdarray, world, nullptr, nullptr)); + + for (int i = 0; i < world; ++i) { + if (i == rank && (fdarray[i].revents & POLLIN)) { + // auto accept_sock = serv_accept(fdarray[i].fd); + // future_fds[slot ++] = std::async( + // std::launch::async, [=]() { + // struct sock_guard{ + // int sock; + // sock_guard(int sock) : sock(sock) {} + // ~guard_sock() {sysCheck(close(sock));} + // } release(accept_sock); + // auto ret = un_recv_fd(accept_sock); + // return ret;}); + recv_socks[slot++] = serv_accept(fdarray[i].fd); + } else if ( + (send_progress & (1 << i)) == 0 && fdarray[i].revents & POLLOUT) { + un_send_fd(fdarray[i].fd, send_buf->fd, rank, send_buf->offset); + send_progress |= 1 << i; + } + } + } + + for (int i = 0; i < world - 1; ++i) { + // future_fds[i].wait(); + // auto [fd, peer, offset] = future_fds[i].get(); + auto [fd, peer, offset] = un_recv_fd(recv_socks[i]); + recv_buf[peer].fd = fd; + recv_buf[peer].offset = offset; + } + + recv_buf[rank] = *send_buf; +} + +template < + typename data_type, + uint32_t max_rank = 8, + uint32_t max_buffer = 1024 /*KB*/> +class allreducer { + public: 
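  // un_allgather() above is the transport used by exchange_peer_ipc_mem()
  // further down. A condensed sketch of one rank's call (illustrative, not
  // from the patch; l0_ctx, base_addr, ptr, rank and world come from the
  // caller):
  //
  //   exchange_contents send_buf{};
  //   zeCheck_dynamic(zeMemGetIpcHandle_dynamic(l0_ctx, base_addr, &send_buf.ipc_handle));
  //   send_buf.offset = (char*)ptr - (char*)base_addr;  // ptr may sit inside a larger allocation
  //   send_buf.pid = getpid();
  //
  //   exchange_contents recv_buf[world];                 // one slot per rank
  //   un_allgather(&send_buf, recv_buf, rank, world);    // blocking full-mesh fd exchange
  //   // recv_buf[r].fd / recv_buf[r].offset now describe rank r's allocation.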
+ allreducer() { + initialized = false; + size_per_buffer = 0; + buffer_index = 0; + } + allreducer(const allreducer&) = delete; + allreducer& operator=(const allreducer&) = delete; + allreducer(allreducer&& other) noexcept { + *this = std::move(other); + } + allreducer& operator=(allreducer&& other) noexcept { + if (this != &other) { + initialized = other.initialized; + rank = other.rank; + world = other.world; + std::memcpy(buffers, other.buffers, sizeof(buffers)); + std::memcpy(offsets, other.offsets, sizeof(offsets)); + std::memcpy(ipc_handle, other.ipc_handle, sizeof(ipc_handle)); + + other.initialized = false; + } + return *this; + } + ~allreducer() { + if (initialized) { + std::cerr << "Warning: allreducer destroyed without calling release()" + << std::endl; + } + } + + void init(sycl::queue& queue, uint32_t rank_in, uint32_t world_in) { + if (initialized) + return; + + if (!load_level_zero_library()) { + throw std::runtime_error("Failed to initialize Level Zero"); + } + + zeCheck_dynamic(zeInit_dynamic(0)); + int tmp_rank, tmp_world; + + tmp_world = world_in; + tmp_rank = rank_in; + + rank = tmp_rank; + world = tmp_world; + initialized = true; + } + void allreduce(sycl::queue& queue, void* inout_buffer, uint32_t size) {} + void release(sycl::queue& queue) { + if (!initialized) + return; + try { + auto l0_ctx = sycl::get_native( + queue.get_context()); + for (int i = 0; i < world; i++) { + if (i != rank) { + zeCheck_dynamic(zeMemCloseIpcHandle_dynamic( + l0_ctx, (char*)buffers[i] - offsets[i])); + } + } + } catch (const std::exception& e) { + std::cerr << "Warning: Level Zero cleanup failed: " << e.what() + << std::endl; + } + sycl::free(buffers[rank], queue); + initialized = false; + } + + void debug_print_buffer(sycl::queue& queue, int* address, int count) { + auto host_ptr = (int*)sycl::malloc_host(count * sizeof(int), queue); + auto tmp_ptr = (int*)sycl::malloc_device(count * sizeof(int), queue); + + queue.memcpy(tmp_ptr, address, count * sizeof(int)); + queue.memcpy(host_ptr, tmp_ptr, count * sizeof(int)); + + queue.wait(); + + for (int i = 0; i < count; i++) { + std::cout << host_ptr[i] << " "; + } + std::cout << std::endl; + } + // buffer_size as element size + void exchange_peer_ipc_mem(sycl::queue& queue, void* ptr) { + if (!load_level_zero_library()) { + throw std::runtime_error("Level Zero not available"); + } + + // Step 1: Get base address of the pointer + sycl::context ctx = queue.get_context(); + auto l0_ctx = sycl::get_native(ctx); + + void* base_addr; + size_t base_size; + zeCheck_dynamic( + zeMemGetAddressRange_dynamic(l0_ctx, ptr, &base_addr, &base_size)); + + // Step 2: Get IPC mem handle from base address + alignas(64) exchange_contents send_buf; + alignas(64) exchange_contents recv_buf[world]; + + // fill in the exchange info + zeCheck_dynamic( + zeMemGetIpcHandle_dynamic(l0_ctx, base_addr, &send_buf.ipc_handle)); + send_buf.offset = (char*)ptr - (char*)base_addr; + + send_buf.pid = getpid(); + + // Step 3: Exchange the handles and offsets + memset(recv_buf, 0, sizeof(recv_buf)); + // Overkill if we don't really needs all peer's handles + un_allgather(&send_buf, recv_buf, rank, world); + for (uint32_t i = 0; i < world; i++) { + // Step 4: Prepare pid file descriptor of next process + auto* peer = recv_buf + i; + // Step 6: Open IPC handle of remote peer + auto l0_device = sycl::get_native( + queue.get_device()); + void* peer_base; + + zeCheck_dynamic(zeMemOpenIpcHandle_dynamic( + l0_ctx, + l0_device, + peer->ipc_handle, + ZE_IPC_MEMORY_FLAG_BIAS_CACHED, + 
&peer_base)); + + buffers[i] = (char*)peer_base + peer->offset; + // make sure data correction + // debug_print_buffer(queue, static_cast(buffers[i]), + // ELE_COUNT); + offsets[i] = peer->offset; + ipc_handle[i] = send_buf.ipc_handle; + } + } + + bool initialized; + void* buffers[max_rank]; + void* sync_buffer[max_rank]; + size_t offsets[max_rank]; + ze_ipc_mem_handle_t ipc_handle[max_rank]; + int rank, world; + int size_per_buffer; + int data_size_per_buffer; + int buffer_index; +}; diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 5db69b6ea1..9108a96a42 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -1911,7 +1911,7 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { } auto currentStream = at::xpu::getCurrentXPUStream(barDevIdx); - currentStream.synchronize(); + // currentStream.synchronize(); // zl_debug workaround for symm barrier return nullptr; } diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp new file mode 100644 index 0000000000..ae35fa088c --- /dev/null +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -0,0 +1,537 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +// todo: fixed with kernel barrier +#include + +#define MAX_RANK 8 + +namespace c10d { +namespace symmetric_memory { + +/* Start of XPUSymmetricMemory implementation */ + +// A set of exchange methods with prefix "XPUSymmetricMemory" +static StoreExchange storeExchange = StoreExchange("XPUSymmetricMemory"); + +AllocationRef::AllocationRef( + void* ptr, + HandleType handle, + size_t block_size, + int device_idx) + : ptr(ptr), + handle(handle), + block_size(block_size), + device_idx(device_idx) {} + +AllocationRef::~AllocationRef() { + if (is_finalizing()) { + return; + } + c10::Device local_device(c10::DeviceType::XPU, device_idx); + c10::DeviceGuard guard(local_device); + c10::xpu::syncStreamsOnDevice(); +} + +XPUSymmetricMemory::XPUSymmetricMemory( + std::vector> alloc_refs, + std::vector buffers, + std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size) + : alloc_refs_(std::move(alloc_refs)), + buffers_(std::move(buffers)), + signal_pads_(std::move(signal_pads)), + mc_handle_(mc_handle), + mc_addr_(mc_addr), + buffer_size_(buffer_size), + local_device_idx_(local_device_idx), + rank_(rank), + world_size_(world_size) { + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx); + c10::DeviceGuard guard(local_device); + + at::xpu::getCurrentXPUStream().queue().memcpy( + buffers_dev_, buffers_.data(), arr_size); + at::xpu::getCurrentXPUStream().queue().memcpy( + signal_pads_dev_, signal_pads_.data(), arr_size); +} + +std::vector XPUSymmetricMemory::get_buffer_ptrs() { + return buffers_; +} + +std::vector XPUSymmetricMemory::get_signal_pad_ptrs() { + return signal_pads_; +} + +void** XPUSymmetricMemory::get_buffer_ptrs_dev() { + return buffers_dev_; +} + +void** XPUSymmetricMemory::get_signal_pad_ptrs_dev() { + return signal_pads_dev_; +} + +size_t XPUSymmetricMemory::get_buffer_size() { + return buffer_size_; +} + +size_t XPUSymmetricMemory::get_signal_pad_size() { + return signal_pad_size; +} + +bool 
XPUSymmetricMemory::has_multicast_support() { + return mc_addr_ != nullptr; +} + +void* XPUSymmetricMemory::get_multicast_ptr() { + return mc_addr_; +} + +void XPUSymmetricMemory::copy_buffer( + at::Tensor src, + at::Tensor dst, + size_t size) { + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + auto src_ptr = src.data_ptr(); + auto dst_ptr = dst.data_ptr(); + + size_t copy_size = size * c10::elementSize(src.scalar_type()); + + current_queue.memcpy(dst_ptr, src_ptr, copy_size); +} +at::Tensor XPUSymmetricMemory::get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "XPUSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::XPU, local_device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); +} + +at::Tensor XPUSymmetricMemory::get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) { + // If the dtype is unspecified, default it to UInt32, as it + // is the most common type for signaling purposes. + if (!dtype.has_value()) { + dtype = c10::ScalarType::UInt32; + } + + // If the shape is unspecified, treat the signal pad as a 1d tensor. 
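  // For example, with these defaults and signal_pad_size = 2048, a call like
  // get_signal_pad(rank, {}) yields a UInt32 tensor of 2048 / 4 = 512 elements.
  //
  // Usage sketch for get_buffer()/copy_buffer() above (illustrative, not from
  // the patch; `symm` stands for the handle returned by
  // XPUSymmetricMemoryAllocator::rendezvous()):
  //
  //   at::Tensor peer_buf = symm->get_buffer(/*rank=*/1, {1024}, at::kFloat, /*storage_offset=*/0);
  //   at::Tensor local = at::ones({1024}, peer_buf.options());
  //   symm->copy_buffer(local, peer_buf, /*size=*/1024);   // size is an element count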
+ const auto element_size = c10::elementSize(*dtype); + std::vector shape; + if (sizes.size() != 0) { + shape = sizes.vec(); + } else { + shape.push_back(signal_pad_size / element_size); + } + + const size_t numel = std::accumulate( + shape.begin(), + + shape.end(), + static_cast(1), + std::multiplies()); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= signal_pad_size, + "XPUSymmetricMemory::get_signal_pad: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + signal_pad_size, + " bytes)"); + auto data_ptr = reinterpret_cast(signal_pads_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::XPU, local_device_idx_); + auto options = at::TensorOptions().dtype(*dtype).device(device); + return at::for_blob(data_ptr, shape) + .options(options) + .target_device(device) + .make_tensor(); +} + +void check_channel(int channel, int world_size) { + TORCH_CHECK( + channel >= 0, + "channel for barrier(), put_signal() and wait_signal() ", + "must be greater than 0 (got ", + channel, + ")"); + const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + TORCH_CHECK( + static_cast(channel) < num_channels, + "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", + num_channels - 1, + " (got ", + channel, + ")"); +} + +void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { + check_channel(channel, world_size_); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); + c10::DeviceGuard guard(local_device); + + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); +} + +void XPUSymmetricMemory::put_signal( + int dst_rank, + int channel, + size_t timeout_ms) { + LOG(ERROR) << "XPUSymmetricMemory::put_signal not supported"; +} + +void XPUSymmetricMemory::wait_signal( + int src_rank, + int channel, + size_t timeout_ms) { + LOG(ERROR) << "XPUSymmetricMemory::wait_signal not supported"; +} + +int XPUSymmetricMemory::get_rank() { + return rank_; +} + +int XPUSymmetricMemory::get_world_size() { + return world_size_; +} + +Block::Block( + c10::intrusive_ptr alloc_ref, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::optional& group_name) + : alloc_ref(std::move(alloc_ref)), + device_idx(device_idx), + block_size(block_size), + buffer_size(buffer_size), + signal_pad_offset(signal_pad_offset), + default_group_name(std::move(group_name)) {} + +void* XPUSymmetricMemoryAllocator::alloc( + size_t size, + int device_idx, + const std::optional& group_name) { + size_t signal_pad_offset = at::round_up(size, 16UL); + size_t block_size = signal_pad_offset + signal_pad_size; + + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + sycl::context sycl_ctx = current_queue.get_context(); + sycl::device sycl_dev = current_queue.get_device(); + ze_context_handle_t ze_ctx = + sycl::get_native(sycl_ctx); + ze_device_handle_t ze_dev = + sycl::get_native(sycl_dev); + + ze_physical_mem_desc_t phys_desc = { + ZE_STRUCTURE_TYPE_PHYSICAL_MEM_DESC, nullptr, 0, block_size}; + + ze_physical_mem_handle_t handle = nullptr; + + ze_device_mem_alloc_desc_t default_device_mem_alloc_desc = { + .stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + .pNext = nullptr, + .flags = 0, + .ordinal = 0}; + + void* ptr = sycl::malloc_device(block_size, current_queue); + current_queue.memset(ptr, 0, block_size); + auto alloc_ref = + c10::make_intrusive(ptr, ptr, block_size, device_idx); + auto block = 
c10::make_intrusive( + std::move(alloc_ref), + device_idx, + block_size, + size, + signal_pad_offset, + group_name); + + { + std::unique_lock lock(mutex_); + ptr_to_block_.emplace(ptr, std::move(block)); + } + // check this ptr copy to sycl buffer + + return ptr; +} + +void XPUSymmetricMemoryAllocator::free(void* ptr) { + std::unique_lock lock(mutex_); + ptr_to_block_.erase(ptr); +} + +size_t XPUSymmetricMemoryAllocator::get_alloc_size(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "XPUSymmetricMemoryAllocator::get_alloc_size: input must be allocated ", + "via XPUSymmetricMemoryAllocator::alloc"); + return block->buffer_size; +} + +struct RendezvousRequest { + int device_idx; + int pid; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + bool has_multicast_support; + size_t base_offset; +}; + +void validate_rendezvous_requests( + const std::vector& reqs, + int world_size) { + TORCH_CHECK(reqs.size() == (size_t)world_size); + + std::unordered_set device_indices; + device_indices.reserve(world_size); + for (auto req : reqs) { + device_indices.insert(req.device_idx); + } + if (!allow_overlapping_devices() && + device_indices.size() < (size_t)world_size) { + TORCH_CHECK( + false, + "XPUSymmetricMemoryAllocator::rendezvous: ", + "detected allocations from overlapping devices ", + "from different ranks."); + } + + for (int r = 1; r < world_size; ++r) { + TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); + TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); + TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); + } +} + +static bool check_group_multicast_support( + const std::vector& reqs) { + std::vector ranks_with_multicast_support; + for (size_t r = 0; r < reqs.size(); ++r) { + if (reqs[r].has_multicast_support) { + ranks_with_multicast_support.push_back(r); + } + } + if (ranks_with_multicast_support.size() == reqs.size()) { + return true; + } else { + // We don't expect this to happen. But we want to let the user to know if + // this happens. + if (ranks_with_multicast_support.size() != 0) { + LOG(WARNING) + << "Only a subset of ranks in the group has multicast support: " + << ranks_with_multicast_support << " (world_size=" << reqs.size() + << "). Skipping multicast initialization because this is unexpected."; + } + return false; + } +} + +c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( + void* ptr, + const std::optional& group_name) { + auto block = find_block(ptr); + if (block == nullptr) { + return nullptr; + } + + // The group_name passed to rendezvous() takes precedence over + // the default group_name specified during allocation. 
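  // For example (illustrative, not from the patch; group names are
  // placeholders):
  //
  //   allocator->alloc(size, device_idx, "tp_group");   // records "tp_group" as the default
  //   allocator->rendezvous(ptr, std::nullopt);         // -> uses "tp_group"
  //   allocator->rendezvous(ptr, "pp_group");           // -> uses "pp_group"
  //
  // An empty string is treated like std::nullopt; if no default was recorded
  // at allocation time and none is passed here, rendezvous() raises an error.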
+ std::string group_name_; + // Treat empty string and std::nullopt the same as empty string seems to be + // implicitly used that way + if (group_name.has_value() && group_name != "") { + group_name_ = *group_name; + } else { + if (!block->default_group_name.has_value()) { + TORCH_CHECK( + false, + "XPUSymmetricMemory::rendezvous: `group_name` is neither " + "specified during allocation nor passed to rendezvous()."); + } + group_name_ = *block->default_group_name; + } + + auto it = block->symm_mems.find(group_name_); + if (it != block->symm_mems.end()) { + return it->second; + } + + c10::Device local_device(c10::DeviceType::XPU, block->device_idx); + c10::DeviceGuard guard(local_device); + + // Currently, IpcChannel is using a file based socket for inter-process + // communication + IpcChannel ipc_channel; + auto group_info = get_group_info(group_name_); + auto store = group_info.store; + int rank = group_info.rank; + int world_size = group_info.world_size; + int block_fd; + + // Step 6: Open IPC handle of remote peer + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + sycl::context ctx = current_queue.get_context(); + auto l0_ctx = sycl::get_native(ctx); + sycl::device dev = current_queue.get_device(); + auto l0_dev = sycl::get_native(dev); + // check with original ones // debug code + // initialize MPI done + allreducer ar; + ar.init(current_queue, rank, world_size); + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset, + .has_multicast_support = false, + .base_offset = 0}; + auto reqs = storeExchange.all_gather(store, rank, world_size, local_req); + validate_rendezvous_requests(reqs, world_size); + + std::vector pids(world_size); + for (int r = 0; r < world_size; ++r) { + pids[r] = reqs[r].pid; + } + + // do IPC exchange for all peer ranks + ar.exchange_peer_ipc_mem(current_queue, ptr); + + // auto imported_fds = ipc_channel.all_gather_fds(rank, pids, block_fd); + + std::vector handles(world_size); + std::vector buffers(world_size, nullptr); + std::vector signal_pads(world_size, nullptr); + + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + handles[r] = block->alloc_ref->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } else { + buffers[r] = ar.buffers[r]; + handles[r] = ar.buffers[r]; // ar.ipc_handle[r]; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + } + } + storeExchange.barrier(store, rank, world_size); + + HandleType mc_handle{}; + void* mc_addr = nullptr; + bool group_has_multicast_support = check_group_multicast_support(reqs); + // todo: not support multicast now + std::vector> alloc_refs; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + alloc_refs.emplace_back(block->alloc_ref); + continue; + } + alloc_refs.push_back(c10::make_intrusive( + buffers[r], handles[r], block->block_size, block->device_idx)); + } + + auto symm_mem = c10::make_intrusive( + std::move(alloc_refs), + std::move(buffers), + std::move(signal_pads), + mc_handle, + mc_addr, + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); + block->symm_mems[group_name_] = symm_mem; + return symm_mem; +} + +bool XPUSymmetricMemoryAllocator::has_multicast_support(int device_idx) { + return false; +} + +c10::DeviceType XPUSymmetricMemoryAllocator::supported_device_type() { + return 
c10::DeviceType::XPU; +} + +std::string XPUSymmetricMemoryAllocator::name() { + return "XPU"; +} + +c10::intrusive_ptr XPUSymmetricMemoryAllocator::find_block(void* ptr) { + std::shared_lock lock(mutex_); + auto it = ptr_to_block_.find(ptr); + if (it == ptr_to_block_.end()) { + return nullptr; + } + return it->second; +} + +struct RegisterXPUSymmetricMemoryAllocator { + RegisterXPUSymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); + // Query backend used for XPU + if (getSymmMemBackendXPU() == "XPU") { + // Direct set (static registration) + register_allocator(c10::DeviceType::XPU, allocator); + } else { + // Register availability in case `set_backend` is called dynamically + register_availability("XPU", allocator); + } + } +}; +static RegisterXPUSymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp new file mode 100644 index 0000000000..aa7f1c1660 --- /dev/null +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10d::symmetric_memory { + +// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon +// destruction, it unmaps the vaddr and releases the allocation handle. +struct AllocationRef : public c10::intrusive_ptr_target { + void* ptr; + HandleType handle; + size_t block_size; + int device_idx; + + AllocationRef( + void* ptr, + HandleType handle, + size_t block_size, + int device_idx); + + ~AllocationRef(); +}; + +class XPUSymmetricMemory : public SymmetricMemory { + public: + XPUSymmetricMemory( + std::vector> alloc_refs, + std::vector buffers, + std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size); + + ~XPUSymmetricMemory() override{}; + + std::vector get_buffer_ptrs() override; + std::vector get_signal_pad_ptrs() override; + void** get_buffer_ptrs_dev() override; + void** get_signal_pad_ptrs_dev() override; + size_t get_buffer_size() override; + size_t get_signal_pad_size() override; + + bool has_multicast_support() override; + void* get_multicast_ptr() override; + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) override; + + at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) override; + + void barrier(int channel, size_t timeout_ms) override; + void put_signal(int dst_rank, int channel, size_t timeout_ms) override; + void wait_signal(int src_rank, int channel, size_t timeout_ms) override; + void copy_buffer(at::Tensor src, at::Tensor dst, size_t size) override; + + int get_rank() override; + int get_world_size() override; + + private: + std::vector> alloc_refs_; + std::vector buffers_; + std::vector signal_pads_; + HandleType mc_handle_; + void* mc_addr_; + size_t buffer_size_; + int local_device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; +}; + +struct Block : public c10::intrusive_ptr_target { + c10::intrusive_ptr alloc_ref; + int device_idx; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + std::optional default_group_name; + std::map> symm_mems; + + Block( + c10::intrusive_ptr alloc_ref, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::optional& group_name); +}; + +class XPUSymmetricMemoryAllocator : public 
SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::optional& group_name) override; + + void free(void* ptr) override; + size_t get_alloc_size(void* ptr) override; + c10::intrusive_ptr rendezvous( + void* ptr, + const std::optional& group_name) override; + bool has_multicast_support(int device_idx) override; + // void exchange_peer_ipc_mem(sycl::queue& queue, void* ptr); + c10::DeviceType supported_device_type() override; + std::string name() override; + + private: + c10::intrusive_ptr find_block(void* ptr); + + std::shared_mutex mutex_; + std::unordered_map> ptr_to_block_; +}; + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryTypes.hpp b/src/xccl/XPUSymmetricMemoryTypes.hpp new file mode 100644 index 0000000000..133abd2712 --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryTypes.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace c10d::symmetric_memory { + +constexpr size_t signal_pad_size = 2048; +using HandleType = void*; + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp new file mode 100644 index 0000000000..551e12abc5 --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -0,0 +1,247 @@ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace c10d::symmetric_memory { + +std::string getSymmMemBackendXPU() { + static auto val = c10::utils::get_env("TORCH_SYMMMEM"); + if (val.has_value()) { + TORCH_CHECK( + val.value() == "XPU", + "TORCH_SYMMMEM environment variable must be 'XPU'."); + return val.value(); + } + return "XPU"; +} + +bool device_has_multicast_support(int device_idx) { + return false; +} + +bool allow_overlapping_devices() { + return c10::utils::check_env("TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES") == + true; +} + +IpcChannel::IpcChannel() + : socket_name_(get_socket_name(getpid())), + socket_(socket(AF_UNIX, SOCK_DGRAM, 0)) { + // On success, a file descriptor for the new socket is returned. + // On error, -1 is returned, and errno is set to indicate the error. + TORCH_CHECK( + socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno)); + + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); + + TORCH_CHECK( + bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0, + "Failed to bind socket: ", + c10::utils::str_error(errno)); +} + +IpcChannel::~IpcChannel() { + close(socket_); + unlink(socket_name_.c_str()); +} + +void IpcChannel::send_fd(int dst_pid, int fd) { + // Because file descriptors are process-local kernel objects, and we can’t + // pass them via normal socket payloads (like write() or send()). Unix domain + // sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg(). 
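  // Usage sketch for IpcChannel (illustrative, not from the patch): each
  // process binds one datagram socket named <tmpdir>/symm_mem-<pid>, so a
  // peer's pid (pid_of_rank0 / pid_of_rank1 below are placeholders, gathered
  // beforehand via the store) is all that is needed to address it.
  //
  //   IpcChannel chan;                      // binds <tmpdir>/symm_mem-<own pid>
  //   // on rank 0:
  //   chan.send_fd(pid_of_rank1, my_fd);
  //   int peer_fd = chan.recv_fd();
  //   // on rank 1, the mirror image:
  //   int peer_fd = chan.recv_fd();
  //   chan.send_fd(pid_of_rank0, my_fd);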
+ // Define destination socket address + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + auto socket_name = get_socket_name(dst_pid); + std::copy(socket_name.begin(), socket_name.end(), addr.sun_path); + + // Prepare data to send + // Data being sent is "fd", the value of fd will be sent as auxiliary data + // (control message) + struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2}; + + // Prepare control message data buffer and zero it out + // NOLINTNEXTLINE(*array*) + char cbuf[CMSG_SPACE(sizeof(int))]; + memset(cbuf, 0, sizeof(cbuf)); + + // Create message header + struct msghdr msg { + // destination socket address and size of it + // message content in msg_iov and number of such structs (1 in our case) + // auxiliary data with the value of fd and size of it + .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un), + .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf, + .msg_controllen = sizeof(cbuf) + }; + + // This points to the first control message header + // With SCM_RIGHTS we let the kernel know that we are passing file + // descriptors. + auto cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + // Specify socket level message + cmsg->cmsg_level = SOL_SOCKET; + // SCM_RIGHTS is the type used to pass file descriptors + cmsg->cmsg_type = SCM_RIGHTS; + + if (fd != -1) { + std::copy( + reinterpret_cast(&fd), + reinterpret_cast(&fd) + sizeof(fd), + reinterpret_cast(CMSG_DATA(cmsg))); + } else { + msg.msg_controllen = 0; + } + + // Finally send the the message + TORCH_CHECK( + sendmsg(socket_, &msg, 0) > 0, + "Failed to send fd: ", + c10::utils::str_error(errno)); +} + +int IpcChannel::recv_fd() { + // Prepare buffer for regular message "fd" + // NOLINTNEXTLINE(*array*) + char buf[2]; + memset(&buf, 0, sizeof(buf)); + struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)}; + + // Prepare buffer for control message and zero it out + // NOLINTNEXTLINE(*array*) + char cbuf[CMSG_SPACE(sizeof(int))]; + memset(cbuf, 0, sizeof(cbuf)); + + // Define socket address to receive on: family AF_UNIX means unix domain + // socket + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); + + // Prepare message header + struct msghdr msg = { + .msg_name = (void*)&addr, + .msg_namelen = sizeof(struct sockaddr_un), + .msg_iov = &io, + .msg_iovlen = 1, + .msg_control = cbuf, + .msg_controllen = sizeof(cbuf)}; + + // Recieve message on socket_ + TORCH_CHECK( + recvmsg(socket_, &msg, 0) > 0, + "Failed to receive fd: ", + c10::utils::str_error(errno)); + + if (msg.msg_controllen == 0) { + return -1; + } + + // Extract control message and validate its content + auto cmsg = CMSG_FIRSTHDR(&msg); + TORCH_CHECK(cmsg != nullptr); + TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int))); + TORCH_CHECK(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS); + return *reinterpret_cast(CMSG_DATA(cmsg)); +} + +std::vector IpcChannel::all_gather_fds( + int rank, + const std::vector& pids, + int fd) { + int world_size = (int)pids.size(); + std::vector fds(pids.size()); + fds[rank] = fd; + + int dst_rank = (rank + 1) % world_size; + for (int step = 1; step < world_size; ++step) { + int src_rank = (rank + world_size - step) % world_size; + send_fd(pids[dst_rank], fd); + fd = recv_fd(); + fds[src_rank] = fd; + } + return fds; +} + +int IpcChannel::broadcast_fds( + int rank, + int src_rank, + const std::vector& pids, + int fd) { + int world_size = (int)pids.size(); + + if (rank == src_rank) { 
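    // Worked example of the ring schedule in all_gather_fds() above
    // (illustrative, not from the patch): every step forwards the most
    // recently received fd to the fixed neighbour (rank + 1) % world_size, so
    // after `step` hops the received descriptor originated at
    // (rank - step) mod world_size.
    //
    //   // world_size = 4, rank = 1:
    //   //   step 1: send own fd to rank 2, receive rank 0's  -> fds[0]
    //   //   step 2: forward rank 0's fd,   receive rank 3's  -> fds[3]
    //   //   step 3: forward rank 3's fd,   receive rank 2's  -> fds[2]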
+ for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) { + if (dst_rank == rank) { + continue; + } + send_fd(pids[dst_rank], fd); + } + return fd; + } + return recv_fd(); +} + +std::string IpcChannel::get_socket_name(int pid) { + const char* tmp_dir = "/tmp"; + for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) { + if (const char* path = getenv(env_var)) { + tmp_dir = path; + break; + } + } + std::ostringstream oss; + oss << tmp_dir << "/symm_mem-" << pid; + return oss.str(); +} + +void map_block( + void** ptr, + ze_physical_mem_handle_t handle, + size_t size, + int device_idx) { + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + sycl::context sycl_ctx = current_queue.get_context(); + ze_context_handle_t ze_context = + sycl::get_native(sycl_ctx); + // 1. Reserve virtual address space + void* virtual_ptr = nullptr; + ze_result_t status = zeVirtualMemReserve( + ze_context, // context + nullptr, // let L0 pick virtual address + size, // size + &virtual_ptr // out: reserved address + ); + TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemReserve failed"); + + // 2. Map physical memory to virtual address + status = zeVirtualMemMap( + ze_context, + virtual_ptr, // virtual memory to map to + size, + handle, // physical memory handle + 0, // flags + ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE // ze_memory_access_attribute_t + ); + TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemMap failed"); + + // 3. Set access attributes + ze_memory_access_attribute_t access = ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE; + status = + zeVirtualMemSetAccessAttribute(ze_context, virtual_ptr, size, access); + TORCH_CHECK( + status == ZE_RESULT_SUCCESS, "zeVirtualMemSetAccessAttribute failed"); + + // 4. Return pointer + *ptr = virtual_ptr; +} + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp new file mode 100644 index 0000000000..f119928f41 --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -0,0 +1,113 @@ +#pragma once +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +std::string getSymmMemBackendXPU(); + +bool device_has_multicast_support(int device_idx); + +bool allow_overlapping_devices(); + +class IpcChannel { + public: + IpcChannel(); + ~IpcChannel(); + + void send_fd(int dst_pid, int fd); + int recv_fd(); + + std::vector all_gather_fds( + int rank, + const std::vector& pids, + int fd); + + int broadcast_fds( + int rank, + int src_rank, + const std::vector& pids, + int fd); + + private: + static std::string get_socket_name(int pid); + + std::string socket_name_; + int socket_; +}; + +// A set of store-based exchange methods with a preset prefix typically type of +// the SymmetricMemory. Most used as static instances at respective +// SymmetricMemory implementation files. +class StoreExchange { + public: + StoreExchange(const std::string& store_prefix) + : store_prefix_(store_prefix) {} + + // Put template function in header file so that compiler can easily access it. 
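  // Usage sketch for StoreExchange (illustrative, not from the patch): any
  // trivially copyable value can be gathered through a c10d::Store handle such
  // as the one returned by get_group_info(group_name).store; this is how
  // XPUSymmetricMemoryAllocator::rendezvous() shares its RendezvousRequest.
  //
  //   StoreExchange exchange("demo");
  //   struct Info { int device_idx; size_t block_size; };       // trivially copyable
  //   Info mine{device_idx, block_size};
  //   std::vector<Info> all = exchange.all_gather(store, rank, world_size, mine);
  //   exchange.barrier(store, rank, world_size);                // all_gather of a dummy int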
+ template + std::vector all_gather( + const c10::intrusive_ptr& store, + int rank, + int world_size, + T val) { + static_assert(std::is_trivially_copyable_v); + + std::vector peer_keys; + peer_keys.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_prefix_ << "/" << seq_id_ << "/" << r; + peer_keys.push_back(oss.str()); + } + ++seq_id_; + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peer_keys[rank], payload); + } + + std::vector peer_vals; + peer_vals.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + peer_vals.push_back(val); + continue; + } + store->wait({peer_keys[r]}); + auto payload = store->get(peer_keys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peer_val{}; + std::memcpy(&peer_val, payload.data(), sizeof(T)); + peer_vals.push_back(peer_val); + } + return peer_vals; + } + + void barrier( + const c10::intrusive_ptr& store, + int rank, + int world_size) { + // TODO: implement an efficient one? + all_gather(store, rank, world_size, 0); + } + + private: + const std::string store_prefix_; + size_t seq_id_ = 0; +}; + +// Teturns a pointer of virtual address that is mapped to the physical memory +// held by the handle. +void map_block( + void** ptr, + ze_physical_mem_handle_t handle, + size_t size, + int device_idx); + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/ze_exception.hpp b/src/xccl/ze_exception.hpp new file mode 100644 index 0000000000..99e4cb6e9e --- /dev/null +++ b/src/xccl/ze_exception.hpp @@ -0,0 +1,254 @@ +#pragma once + +#include +#include +#include +#include + +#define zeVirtualMemMap zeVirtualMemMap_original +#define zeVirtualMemReserve zeVirtualMemReserve_original +#define zeVirtualMemSetAccessAttribute zeVirtualMemSetAccessAttribute_original + +#include + +#undef zeVirtualMemMap +#undef zeVirtualMemReserve +#undef zeVirtualMemSetAccessAttribute + +typedef ze_result_t (*zeInit_t)(ze_init_flags_t flags); +typedef ze_result_t (*zeMemGetAddressRange_t)( + ze_context_handle_t hContext, + const void* ptr, + void** pBase, + size_t* pSize); +typedef ze_result_t (*zeMemGetIpcHandle_t)( + ze_context_handle_t hContext, + const void* ptr, + ze_ipc_mem_handle_t* pIpcHandle); +typedef ze_result_t (*zeMemOpenIpcHandle_t)( + ze_context_handle_t hContext, + ze_device_handle_t hDevice, + ze_ipc_mem_handle_t handle, + ze_ipc_memory_flags_t flags, + void** pptr); +typedef ze_result_t ( + *zeMemCloseIpcHandle_t)(ze_context_handle_t hContext, const void* ptr); +typedef ze_result_t (*zeVirtualMemMap_t)( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_physical_mem_handle_t hPhysicalMemory, + size_t offset, + ze_memory_access_attribute_t access); +typedef ze_result_t (*zeVirtualMemReserve_t)( + ze_context_handle_t hContext, + const void* pStart, + size_t size, + void** pptr); +typedef ze_result_t (*zeVirtualMemSetAccessAttribute_t)( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_memory_access_attribute_t access); + +bool load_level_zero_library(); +void unload_level_zero_library(); + +#define zeCheck_dynamic(x) \ + do { \ + if (!load_level_zero_library()) { \ + throw std::runtime_error("Level Zero library not available"); \ + } \ + ze_result_t result = (x); \ + if (result != ZE_RESULT_SUCCESS) { \ + auto e = zeException(result); \ + std::cout << "Throw " << e.what() << std::endl; \ + throw e; \ + } \ + } while (0) + +#define zeInit_dynamic(flags) zeInit_ptr(flags) 
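// Usage sketch for the *_dynamic wrappers defined here (illustrative, not from
// the patch; l0_ctx and base_addr come from the caller, as in IPCExchange.hpp):
// pairing them with zeCheck_dynamic resolves the loader lazily and turns
// failures into zeException.
//
//   zeCheck_dynamic(zeInit_dynamic(0));
//   ze_ipc_mem_handle_t handle;
//   zeCheck_dynamic(zeMemGetIpcHandle_dynamic(l0_ctx, base_addr, &handle));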
+#define zeMemGetAddressRange_dynamic(ctx, ptr, base, size) \ + zeMemGetAddressRange_ptr(ctx, ptr, base, size) +#define zeMemGetIpcHandle_dynamic(ctx, ptr, handle) \ + zeMemGetIpcHandle_ptr(ctx, ptr, handle) +#define zeMemOpenIpcHandle_dynamic(ctx, dev, handle, flags, ptr) \ + zeMemOpenIpcHandle_ptr(ctx, dev, handle, flags, ptr) +#define zeMemCloseIpcHandle_dynamic(ctx, ptr) zeMemCloseIpcHandle_ptr(ctx, ptr) +#define zeVirtualMemMap_dynamic(ctx, ptr, size, phys_mem, offset, access) \ + zeVirtualMemMap_ptr(ctx, ptr, size, phys_mem, offset, access) +#define zeVirtualMemReserve_dynamic(ctx, start, size, ptr) \ + zeVirtualMemReserve_ptr(ctx, start, size, ptr) +#define zeVirtualMemSetAccessAttribute_dynamic(ctx, ptr, size, access) \ + zeVirtualMemSetAccessAttribute_ptr(ctx, ptr, size, access) + +// Exception handling class +class zeException : std::exception { + const char* zeResultToString(ze_result_t status) const { + static const std::unordered_map zeResultToStringMap{ + {ZE_RESULT_SUCCESS, "[Core] success"}, + {ZE_RESULT_NOT_READY, "[Core] synchronization primitive not signaled"}, + {ZE_RESULT_ERROR_UNINITIALIZED, + "[Validation] driver is not initialized"}, + {ZE_RESULT_ERROR_INVALID_NULL_POINTER, + "[Validation] pointer argument may not be nullptr"}, + {ZE_RESULT_ERROR_INVALID_NULL_HANDLE, + "[Validation] handle argument is not valid"}, + {ZE_RESULT_ERROR_INVALID_ENUMERATION, + "[Validation] enumerator argument is not valid"}, + {ZE_RESULT_ERROR_INVALID_SIZE, "[Validation] size argument is invalid"}, + {ZE_RESULT_ERROR_UNSUPPORTED_SIZE, + "[Validation] size argument is not supported by the device"}, + {ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT, + "[Validation] alignment argument is not supported by the device"}, + {ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, + "[Validation] generic error code for unsupported features"}, + {ZE_RESULT_ERROR_INVALID_NATIVE_BINARY, + "[Validation] native binary is not supported by the device"}, + {ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY, + "[Core] insufficient host memory to satisfy call"}, + {ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY, + "[Core] insufficient device memory to satisfy call"}, + {ZE_RESULT_ERROR_DEVICE_LOST, + "[Core] device hung, reset, was removed, or driver update occurred"}, + {ZE_RESULT_ERROR_MODULE_BUILD_FAILURE, + "[Core] error occurred when building module, see build log for details"}, + {ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE, + "[Validation] object pointed to by handle still in-use by device"}, + }; + auto it = zeResultToStringMap.find(status); + if (it != zeResultToStringMap.end()) + return it->second; + else + return "Unknown Reason"; + } + + public: + zeException(ze_result_t ret) : result_(ret) {} + + ze_result_t result_; + + const char* what() const noexcept override { + return zeResultToString(result_); + } +}; + +#define zeCheck(x) \ + if (x != ZE_RESULT_SUCCESS) { \ + auto e = zeException(x); \ + std::cout << "Throw " << e.what() << std::endl; \ + throw e; \ + } + +static zeInit_t zeInit_ptr = nullptr; +static zeMemGetAddressRange_t zeMemGetAddressRange_ptr = nullptr; +static zeMemGetIpcHandle_t zeMemGetIpcHandle_ptr = nullptr; +static zeMemOpenIpcHandle_t zeMemOpenIpcHandle_ptr = nullptr; +static zeMemCloseIpcHandle_t zeMemCloseIpcHandle_ptr = nullptr; +static zeVirtualMemMap_t zeVirtualMemMap_ptr = nullptr; +static zeVirtualMemReserve_t zeVirtualMemReserve_ptr = nullptr; +static zeVirtualMemSetAccessAttribute_t zeVirtualMemSetAccessAttribute_ptr = + nullptr; + +static void* ze_handle = nullptr; + +inline bool load_level_zero_library() { + if 
(ze_handle != nullptr) { + return true; + } + const char* lib_names[] = {"/usr/lib/x86_64-linux-gnu/libze_loader.so"}; + + for (const char* lib_name : lib_names) { + ze_handle = dlopen(lib_name, RTLD_LAZY); + if (ze_handle != nullptr) { + break; + } + } + + if (ze_handle == nullptr) { + std::cerr << "Failed to load Level Zero library: " << dlerror() + << std::endl; + return false; + } + + zeInit_ptr = (zeInit_t)dlsym(ze_handle, "zeInit"); + zeMemGetAddressRange_ptr = + (zeMemGetAddressRange_t)dlsym(ze_handle, "zeMemGetAddressRange"); + zeMemGetIpcHandle_ptr = + (zeMemGetIpcHandle_t)dlsym(ze_handle, "zeMemGetIpcHandle"); + zeMemOpenIpcHandle_ptr = + (zeMemOpenIpcHandle_t)dlsym(ze_handle, "zeMemOpenIpcHandle"); + zeMemCloseIpcHandle_ptr = + (zeMemCloseIpcHandle_t)dlsym(ze_handle, "zeMemCloseIpcHandle"); + zeVirtualMemMap_ptr = (zeVirtualMemMap_t)dlsym(ze_handle, "zeVirtualMemMap"); + zeVirtualMemReserve_ptr = + (zeVirtualMemReserve_t)dlsym(ze_handle, "zeVirtualMemReserve"); + zeVirtualMemSetAccessAttribute_ptr = (zeVirtualMemSetAccessAttribute_t)dlsym( + ze_handle, "zeVirtualMemSetAccessAttribute"); + + if (!zeInit_ptr || !zeMemGetAddressRange_ptr || !zeMemGetIpcHandle_ptr || + !zeMemOpenIpcHandle_ptr || !zeMemCloseIpcHandle_ptr || + !zeVirtualMemMap_ptr || !zeVirtualMemReserve_ptr || + !zeVirtualMemSetAccessAttribute_ptr) { + std::cerr << "Failed to load Level Zero API functions" << std::endl; + dlclose(ze_handle); + ze_handle = nullptr; + return false; + } + + return true; +} + +inline void unload_level_zero_library() { + if (ze_handle != nullptr) { + dlclose(ze_handle); + ze_handle = nullptr; + zeInit_ptr = nullptr; + zeMemGetAddressRange_ptr = nullptr; + zeMemGetIpcHandle_ptr = nullptr; + zeMemOpenIpcHandle_ptr = nullptr; + zeMemCloseIpcHandle_ptr = nullptr; + zeVirtualMemMap_ptr = nullptr; + zeVirtualMemReserve_ptr = nullptr; + zeVirtualMemSetAccessAttribute_ptr = nullptr; + } +} + +extern "C" { + +__attribute__((weak)) ze_result_t zeVirtualMemMap( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_physical_mem_handle_t hPhysicalMemory, + size_t offset, + ze_memory_access_attribute_t access) { + if (!load_level_zero_library() || !zeVirtualMemMap_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemMap_ptr( + hContext, ptr, size, hPhysicalMemory, offset, access); +} + +__attribute__((weak)) ze_result_t zeVirtualMemReserve( + ze_context_handle_t hContext, + const void* pStart, + size_t size, + void** pptr) { + if (!load_level_zero_library() || !zeVirtualMemReserve_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemReserve_ptr(hContext, pStart, size, pptr); +} + +__attribute__((weak)) ze_result_t zeVirtualMemSetAccessAttribute( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_memory_access_attribute_t access) { + if (!load_level_zero_library() || !zeVirtualMemSetAccessAttribute_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemSetAccessAttribute_ptr(hContext, ptr, size, access); +} +}
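// Usage note for the weak shims above (illustrative, not from the patch;
// ze_context and block_size are placeholders): because the real declarations
// were renamed to *_original before including the Level Zero header, plain
// calls such as the one below bind to these weak fallbacks (a strong
// definition from the real loader, if linked, takes precedence); the fallback
// lazily dlopens libze_loader.so and forwards the call, or reports
// ZE_RESULT_ERROR_UNINITIALIZED when the library cannot be loaded.
//
//   void* reserved = nullptr;
//   ze_result_t st = zeVirtualMemReserve(ze_context, nullptr, block_size, &reserved);
//   if (st == ZE_RESULT_ERROR_UNINITIALIZED) {
//     // libze_loader.so could not be loaded or the symbol was not found.
//   }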