diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index b99a24cfd..9f3f656a6 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -21,6 +21,9 @@ #include #ifdef USE_MNNVL +#include // pytorch dependencies +#include + #include "transport/nvlink_transport/nvlink_transport.h" static void *allocateMemory(size_t size) { return mooncake::NvlinkTransport::allocatePinnedLocalMemory(size); @@ -28,9 +31,18 @@ static void *allocateMemory(size_t size) { static void freeMemory(void *ptr) { mooncake::NvlinkTransport::freePinnedLocalMemory(ptr); } +static void synchronizePyTorchEvents() { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + cudaEvent_t ev; + cudaEventCreate(&ev); + cudaEventRecord(ev, stream); + cudaEventSynchronize(ev); + cudaEventDestroy(ev); +} #else static void *allocateMemory(size_t size) { return malloc(size); } static void freeMemory(void *ptr) { free(ptr); } +static void synchronizePyTorchEvents() {} #endif TransferEnginePy::TransferEnginePy() { @@ -258,6 +270,7 @@ int TransferEnginePy::transferSync(const char *target_hostname, uintptr_t peer_buffer_address, size_t length, TransferOpcode opcode) { pybind11::gil_scoped_release release; + synchronizePyTorchEvents(); Transport::SegmentHandle handle; { std::lock_guard guard(mutex_); @@ -331,6 +344,7 @@ int TransferEnginePy::batchTransferSync( std::vector peer_buffer_addresses, std::vector lengths, TransferOpcode opcode) { pybind11::gil_scoped_release release; + synchronizePyTorchEvents(); Transport::SegmentHandle handle; { std::lock_guard guard(mutex_); @@ -419,6 +433,7 @@ batch_id_t TransferEnginePy::batchTransferAsync( const std::vector &peer_buffer_addresses, const std::vector &lengths, TransferOpcode opcode) { pybind11::gil_scoped_release release; + synchronizePyTorchEvents(); Transport::SegmentHandle handle; { std::lock_guard guard(mutex_); @@ -541,6 +556,7 @@ batch_id_t TransferEnginePy::transferSubmitWrite(const char *target_hostname, uintptr_t peer_buffer_address, size_t length) { pybind11::gil_scoped_release release; + synchronizePyTorchEvents(); Transport::SegmentHandle handle; { std::lock_guard guard(mutex_);