Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,28 @@
#include <pybind11/stl.h>

#ifdef USE_MNNVL
#include <ATen/cuda/CUDAContext.h> // pytorch dependencies
#include <cuda_runtime.h>

#include "transport/nvlink_transport/nvlink_transport.h"
// Allocation hook used when the file is built with USE_MNNVL: the request is
// forwarded to NvlinkTransport::allocatePinnedLocalMemory and its result is
// handed back to the caller unchanged.
// NOTE(review): how allocation failure is reported (nullptr or otherwise) is
// decided by NvlinkTransport — confirm there before relying on it.
static void *allocateMemory(size_t size) {
    void *buffer = mooncake::NvlinkTransport::allocatePinnedLocalMemory(size);
    return buffer;
}
static void freeMemory(void *ptr) {
mooncake::NvlinkTransport::freePinnedLocalMemory(ptr);
}
static void synchronizePyTorchEvents() {
auto stream = at::cuda::getCurrentCUDAStream().stream();
cudaEvent_t ev;
cudaEventCreate(&ev);
cudaEventRecord(ev, stream);
cudaEventSynchronize(ev);
cudaEventDestroy(ev);
}
#else
// Fallback allocation hook (build without USE_MNNVL): plain malloc().
// Returns whatever malloc() returns for `size`; release with freeMemory().
static void *allocateMemory(size_t size) {
    void *buffer = malloc(size);
    return buffer;
}
// Fallback release hook (build without USE_MNNVL): plain free().
// `ptr` is passed straight through to free().
static void freeMemory(void *ptr) {
    free(ptr);
}
// Fallback for builds without USE_MNNVL: there is no CUDA stream to wait on
// in this configuration, so this is intentionally a no-op.
static void synchronizePyTorchEvents() {
}
#endif

TransferEnginePy::TransferEnginePy() {
Expand Down Expand Up @@ -258,6 +270,7 @@ int TransferEnginePy::transferSync(const char *target_hostname,
uintptr_t peer_buffer_address, size_t length,
TransferOpcode opcode) {
pybind11::gil_scoped_release release;
synchronizePyTorchEvents();
Transport::SegmentHandle handle;
{
std::lock_guard<std::mutex> guard(mutex_);
Expand Down Expand Up @@ -331,6 +344,7 @@ int TransferEnginePy::batchTransferSync(
std::vector<uintptr_t> peer_buffer_addresses, std::vector<size_t> lengths,
TransferOpcode opcode) {
pybind11::gil_scoped_release release;
synchronizePyTorchEvents();
Transport::SegmentHandle handle;
{
std::lock_guard<std::mutex> guard(mutex_);
Expand Down Expand Up @@ -419,6 +433,7 @@ batch_id_t TransferEnginePy::batchTransferAsync(
const std::vector<uintptr_t> &peer_buffer_addresses,
const std::vector<size_t> &lengths, TransferOpcode opcode) {
pybind11::gil_scoped_release release;
synchronizePyTorchEvents();
Transport::SegmentHandle handle;
{
std::lock_guard<std::mutex> guard(mutex_);
Expand Down Expand Up @@ -541,6 +556,7 @@ batch_id_t TransferEnginePy::transferSubmitWrite(const char *target_hostname,
uintptr_t peer_buffer_address,
size_t length) {
pybind11::gil_scoped_release release;
synchronizePyTorchEvents();
Transport::SegmentHandle handle;
{
std::lock_guard<std::mutex> guard(mutex_);
Expand Down
Loading