diff --git a/.gitignore b/.gitignore index c2568c1af..d0ad29f0f 100644 --- a/.gitignore +++ b/.gitignore @@ -197,4 +197,4 @@ mooncake-wheel/mooncake/mooncake_master mooncake-wheel/mooncake/transfer_engine_bench # Claude Code Memory -CLAUDE.md +CLAUDE.md \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 13179e792..72b9acbee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,24 @@ option(WITH_P2P_STORE "build p2p store library and sample code" OFF) option(WITH_RUST_EXAMPLE "build the Rust interface and sample code for the transfer engine" OFF) option(WITH_EP "build mooncake with expert parallelism support" OFF) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extern/pybind11) +# pybind11: prefer system/finder (>=2.13), then vendored extern, finally FetchContent +find_package(pybind11 2.13 CONFIG QUIET) + +if (NOT pybind11_FOUND) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/extern/pybind11/CMakeLists.txt) + message(STATUS "Using vendored pybind11 from extern/pybind11") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extern/pybind11) + else() + include(FetchContent) + message(STATUS "Fetching pybind11 (v2.13.6) via FetchContent") + FetchContent_Declare(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.13.6 + GIT_SHALLOW TRUE + ) + FetchContent_MakeAvailable(pybind11) + endif() +endif() set(PYTHON_EXECUTABLE "python3") execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import sys; print(sys.path[-1])" diff --git a/mooncake-integration/CMakeLists.txt b/mooncake-integration/CMakeLists.txt index 5a79ff087..76819e514 100644 --- a/mooncake-integration/CMakeLists.txt +++ b/mooncake-integration/CMakeLists.txt @@ -20,7 +20,9 @@ endif() include_directories("/usr/include/jsoncpp") include_directories("./") +include_directories("../mooncake-transfer-engine/include") +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) @@ 
-30,11 +32,17 @@ message("${PYTHON_SYS_PATH}") set(PYTHON_PACKAGE_NAME "mooncake") pybind11_add_module(engine ${SOURCES} ${CACHE_ALLOCATOR_SOURCES} - transfer_engine/transfer_engine_py.cpp + transfer_engine/transfer_engine_py.cpp + $ +) + +target_include_directories(engine PRIVATE + ${Python3_INCLUDE_DIRS} ) target_link_libraries(engine PUBLIC transfer_engine + coro_rpc_connector glog::glog gflags::gflags ) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 316a4d109..499cb961f 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -20,6 +20,11 @@ #include +// Include coro_rpc_interface headers +#include "transport/coro_rpc_connector/cororpc_interface.h" + +using namespace pybind11::literals; + #ifdef USE_MNNVL #include "transport/nvlink_transport/nvlink_transport.h" static void *allocateMemory(size_t size) { @@ -669,6 +674,59 @@ std::vector TransferEnginePy::getNotifies() { namespace py = pybind11; +// Forward declaration for coro_rpc_interface binding function +void bind_coro_rpc_interface(py::module_ &m); + +// Implementation of coro_rpc_interface binding function +void bind_coro_rpc_interface(py::module_ &m) { + using namespace mooncake; + + py::class_(m, "ReceivedData") + .def(py::init<>()) + .def_readonly("source_address", + &CoroRPCInterface::ReceivedData::source_address) + .def_readonly("data_size", &CoroRPCInterface::ReceivedData::data_size) + .def("get_bytes", &CoroRPCInterface::ReceivedData::getBytes); + + py::class_(m, "ReceivedTensor") + .def(py::init<>()) + .def_readonly("source_address", + &CoroRPCInterface::ReceivedTensor::source_address) + .def_readonly("shape", &CoroRPCInterface::ReceivedTensor::shape) + .def_readonly("dtype", &CoroRPCInterface::ReceivedTensor::dtype) + .def_readonly("total_bytes", + &CoroRPCInterface::ReceivedTensor::total_bytes) + .def("get_data_size", 
&CoroRPCInterface::ReceivedTensor::getDataSize) + .def("get_data_as_bytes", + &CoroRPCInterface::ReceivedTensor::getDataAsBytes); + + py::class_(m, "CoroRPCInterface") + .def(py::init<>()) + .def("initialize", &CoroRPCInterface::initialize, + "listen_address"_a = "", "thread_count"_a = 0, + "timeout_seconds"_a = 30, "pool_size"_a = 10) + .def("initialize_client", &CoroRPCInterface::initializeClient, + "pool_size"_a = 10, "timeout_seconds"_a = 30) + .def("initialize_server", &CoroRPCInterface::initializeServer, + "listen_address"_a, "thread_count"_a = 8, "timeout_seconds"_a = 30) + .def("start_server", &CoroRPCInterface::startServer) + .def("start_server_async", &CoroRPCInterface::startServerAsync) + .def("stop_server", &CoroRPCInterface::stopServer) + .def("send_data", &CoroRPCInterface::sendData) + .def("send_data_async", &CoroRPCInterface::sendDataAsync) + .def("send_tensor", &CoroRPCInterface::sendTensor) + .def("send_tensor_async", &CoroRPCInterface::sendTensorAsync) + .def("set_data_receive_callback", + &CoroRPCInterface::setDataReceiveCallback) + .def("set_tensor_receive_callback", + &CoroRPCInterface::setTensorReceiveCallback); + + m.def("create_rpc_client", &createRPCClient, "pool_size"_a = 10, + "timeout_seconds"_a = 30); + m.def("create_rpc_server", &createRPCServer, "listen_address"_a, + "thread_count"_a = 0); +} + PYBIND11_MODULE(engine, m) { py::enum_ transfer_opcode( m, "TransferOpcode", py::arithmetic()); @@ -731,6 +789,8 @@ PYBIND11_MODULE(engine, m) { adaptor_cls.attr("TransferOpcode") = transfer_opcode; + // Bind coro_rpc_interface directly to the main module + bind_coro_rpc_interface(m); py::class_>( m, "InnerTransferEngine"); } diff --git a/mooncake-transfer-engine/include/transport/coro_rpc_connector/cororpc_communicator.h b/mooncake-transfer-engine/include/transport/coro_rpc_connector/cororpc_communicator.h new file mode 100644 index 000000000..8e3f89f25 --- /dev/null +++ 
b/mooncake-transfer-engine/include/transport/coro_rpc_connector/cororpc_communicator.h @@ -0,0 +1,102 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mooncake { + +struct TensorInfo { + void* data_ptr = nullptr; + std::vector shape; + std::string dtype; + size_t total_bytes = 0; +}; + +struct result { + int code = 0; + std::string err_msg; +}; + +struct Config { + std::string listen_address; + size_t thread_count = 0; + size_t timeout_seconds = 30; + size_t pool_size = 10; +}; + +template +struct SimpleContext { + coro_rpc::context context_; + void response_msg() { context_.response_msg(); } +}; + +class CoroRPCCommunicator { + public: + class Impl { + public: + Config config; + bool is_server_started = false; + + std::unique_ptr server_; + + std::function + data_receive_callback; + + pybind11::handle py_callback; + + void handleDataTransfer(coro_rpc::context context, + std::string_view data); + void handleTensorTransfer(coro_rpc::context context); + void handleDataTransferWithAttachment(coro_rpc::context context, + std::string_view data); + void handleTensorTransferWithAttachment( + coro_rpc::context context); + }; + + CoroRPCCommunicator(); + ~CoroRPCCommunicator(); + + bool initialize(const Config& config); + bool startServerImpl(bool is_async = true); + bool startServer(); + bool startServerAsync(); + void stopServer(); + + int sendData(const std::string& target_address, const void* data, + size_t data_size); + async_simple::coro::Lazy sendDataAsync( + const std::string& target_address, const void* data, size_t data_size); + + int sendTensor(const std::string& target_address, + const pybind11::object& tensor); + async_simple::coro::Lazy sendTensorAsync( + const std::string& target_address, const TensorInfo& tensor); + + int receiveData(const std::string& source_address, void* buffer, + size_t buffer_size, int timeout_ms = -1); + 
async_simple::coro::Lazy receiveDataAsync( + const std::string& source_address, int timeout_ms = -1); + + void setDataReceiveCallback( + std::function callback); + + std::shared_ptr getImpl() { return impl_; } + + private: + std::shared_ptr> + client_pools_; + std::shared_ptr impl_; +}; + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/include/transport/coro_rpc_connector/cororpc_interface.h b/mooncake-transfer-engine/include/transport/coro_rpc_connector/cororpc_interface.h new file mode 100644 index 000000000..866b1e97f --- /dev/null +++ b/mooncake-transfer-engine/include/transport/coro_rpc_connector/cororpc_interface.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace mooncake { + +struct Config; + +class CoroRPCInterface { + public: + struct ReceivedData { + std::string source_address; + std::string data; + size_t data_size = 0; + + pybind11::bytes getBytes() const { return pybind11::bytes(data); } + pybind11::memoryview getMemoryView() const { + return pybind11::memoryview::from_memory( + const_cast(data.data()), data.size(), true); + } + }; + + struct ReceivedTensor { + std::string source_address; + std::string data; + std::vector shape; + std::string dtype; + size_t total_bytes = 0; + size_t getDataSize() const { return data.size(); } + pybind11::bytes getDataAsBytes() const { return pybind11::bytes(data); } + pybind11::memoryview getMemoryView() const { + return pybind11::memoryview::from_memory( + const_cast(data.data()), data.size(), true); + } + }; + + class Impl; + + CoroRPCInterface(); + ~CoroRPCInterface(); + + bool initialize(const std::string& listen_address = "", + size_t thread_count = 0, size_t timeout_seconds = 30, + size_t pool_size = 10); + + // Convenience methods for common use cases + bool initializeClient(size_t pool_size = 10, size_t timeout_seconds = 30); + bool initializeServer(const std::string& listen_address, + size_t thread_count = 8, size_t 
timeout_seconds = 30); + + bool startServer(); + bool startServerAsync(); + bool startServerImpl(bool is_async = true); + void stopServer(); + + int sendData(const std::string& target_address, pybind11::handle data); + pybind11::object sendDataAsync(std::string& target_address, + pybind11::handle data, + pybind11::handle loop); + + int sendTensor(const std::string& target_address, pybind11::handle tensor); + pybind11::object sendTensorAsync(std::string& target_address, + pybind11::handle tensor, + pybind11::handle loop); + + void setDataReceiveCallback(pybind11::function callback); + void setTensorReceiveCallback(pybind11::function callback); + + void handleIncomingData(std::string_view source_address, + std::string_view data); + void handleIncomingTensor(std::string_view source_address, + std::string_view data, + const std::vector& shape, + std::string_view dtype); + + private: + std::unique_ptr impl_; +}; + +std::unique_ptr createRPCClient(uint64_t local_rank = 0, + uint64_t world_size = 1); +std::unique_ptr createRPCServer(uint64_t local_rank = 0, + uint64_t world_size = 1); + +} // namespace mooncake + +namespace pybind11 { +class module_; +} +void bind_coro_rpc_interface(pybind11::module_& m); \ No newline at end of file diff --git a/mooncake-transfer-engine/src/CMakeLists.txt b/mooncake-transfer-engine/src/CMakeLists.txt index 769359bfc..79f3b0d15 100644 --- a/mooncake-transfer-engine/src/CMakeLists.txt +++ b/mooncake-transfer-engine/src/CMakeLists.txt @@ -2,6 +2,21 @@ file(GLOB ENGINE_SOURCES "*.cpp") add_subdirectory(common) add_subdirectory(transport) +# Find Python for pybind11 support +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +find_package(pybind11 QUIET) +if(NOT pybind11_FOUND) + execute_process( + COMMAND ${Python3_EXECUTABLE} -m pybind11 --cmakedir + OUTPUT_VARIABLE pybind11_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE pybind11_RESULT + ) + if(pybind11_RESULT EQUAL 0) + find_package(pybind11 REQUIRED PATHS 
${pybind11_DIR}) + endif() +endif() + SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) add_library(transfer_engine ${ENGINE_SOURCES} $) @@ -36,9 +51,14 @@ endif() target_link_libraries( transfer_engine PUBLIC - base transport rdma_transport ibverbs glog::glog gflags::gflags pthread JsonCpp::JsonCpp numa yalantinglibs::yalantinglibs + base transport rdma_transport ibverbs glog::glog gflags::gflags pthread JsonCpp::JsonCpp numa yalantinglibs::yalantinglibs ${Python3_LIBRARIES} ) +# Add pybind11 headers if pybind11 is available +if(TARGET pybind11::headers) + target_link_libraries(transfer_engine PUBLIC pybind11::headers) +endif() + if (USE_CUDA) target_include_directories(transfer_engine PRIVATE /usr/local/cuda/include) target_link_libraries(transfer_engine PUBLIC cuda cudart rt) diff --git a/mooncake-transfer-engine/src/transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/CMakeLists.txt index 5517a5ddc..741f256fd 100644 --- a/mooncake-transfer-engine/src/transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/CMakeLists.txt @@ -1,8 +1,19 @@ file(GLOB XPORT_SOURCES "*.cpp") +# Find Python - pybind11 is already configured at the root level +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + add_subdirectory(rdma_transport) -add_library(transport OBJECT ${XPORT_SOURCES} $) -target_link_libraries(transport PRIVATE JsonCpp::JsonCpp yalantinglibs::yalantinglibs glog::glog pthread) +add_subdirectory(coro_rpc_connector) + +add_library(transport OBJECT ${XPORT_SOURCES} $ $) +target_link_libraries(transport PRIVATE JsonCpp::JsonCpp yalantinglibs::yalantinglibs glog::glog pthread ${Python3_LIBRARIES}) +target_include_directories(transport PRIVATE ${Python3_INCLUDE_DIRS}) + +# Add pybind11 headers if available +if(TARGET pybind11::headers) + target_link_libraries(transport PRIVATE pybind11::headers) +endif() if (USE_TCP) add_subdirectory(tcp_transport) @@ -32,4 +43,4 @@ endif() if (USE_MNNVL) add_subdirectory(nvlink_transport) 
target_sources(transport PUBLIC $) -endif() +endif() \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/coro_rpc_connector/CMakeLists.txt b/mooncake-transfer-engine/src/transport/coro_rpc_connector/CMakeLists.txt new file mode 100644 index 000000000..58ebc4351 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/coro_rpc_connector/CMakeLists.txt @@ -0,0 +1,60 @@ +file(GLOB CORO_RPC_SOURCES "*.cpp") + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + +if (NOT TARGET pybind11::headers) + find_package(pybind11 CONFIG QUIET) + if (NOT pybind11_FOUND) + include(FetchContent) + message(STATUS "Fetching pybind11 (v2.13.6) for coro_rpc_connector") + FetchContent_Declare(pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.13.6 + GIT_SHALLOW TRUE + ) + FetchContent_MakeAvailable(pybind11) + endif() +endif() + +add_library(coro_rpc_connector OBJECT ${CORO_RPC_SOURCES}) + +set(_PYBIND11_INC_DIRS "") +if(TARGET pybind11::headers) + get_target_property(_PYBIND11_INC_DIRS pybind11::headers INTERFACE_INCLUDE_DIRECTORIES) +elseif(TARGET pybind11::pybind11) + get_target_property(_PYBIND11_INC_DIRS pybind11::pybind11 INTERFACE_INCLUDE_DIRECTORIES) +elseif(TARGET pybind11::pybind11_headers) + get_target_property(_PYBIND11_INC_DIRS pybind11::pybind11_headers INTERFACE_INCLUDE_DIRECTORIES) +endif() +if(_PYBIND11_INC_DIRS) + target_include_directories(coro_rpc_connector PRIVATE ${_PYBIND11_INC_DIRS}) +endif() + +target_link_libraries(coro_rpc_connector + PRIVATE + JsonCpp::JsonCpp + yalantinglibs::yalantinglibs + glog::glog + pthread + ${Python3_LIBRARIES} +) + +target_include_directories(coro_rpc_connector + PRIVATE + ${Python3_INCLUDE_DIRS} +) + +# Add pybind11 headers/include paths if available +if(TARGET pybind11::headers) + target_link_libraries(coro_rpc_connector PRIVATE pybind11::headers) +elseif(TARGET pybind11::pybind11) + target_link_libraries(coro_rpc_connector PRIVATE pybind11::pybind11) +else() + 
# Fallback: try to read include dirs from known pybind11 targets if present + if(TARGET pybind11::pybind11_headers) + get_target_property(_PYBIND_INCLUDES pybind11::pybind11_headers INTERFACE_INCLUDE_DIRECTORIES) + if(_PYBIND_INCLUDES) + target_include_directories(coro_rpc_connector PRIVATE ${_PYBIND_INCLUDES}) + endif() + endif() +endif() \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/coro_rpc_connector/cororpc_communicator.cpp b/mooncake-transfer-engine/src/transport/coro_rpc_connector/cororpc_communicator.cpp new file mode 100644 index 000000000..fad60bed7 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/coro_rpc_connector/cororpc_communicator.cpp @@ -0,0 +1,367 @@ +#include "transport/coro_rpc_connector/cororpc_communicator.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "async_simple/coro/SyncAwait.h" + +namespace py = pybind11; + +namespace mooncake { + +class py_rpc_context { + public: + void response_msg(py::buffer msg, py::handle done) { + py::buffer_info info = msg.request(); + const char* data = static_cast(info.ptr); + context_.get_context_info()->set_response_attachment( + std::string_view(data, info.size)); + done.inc_ref(); + context_.get_context_info()->set_complete_handler( + [done](const std::error_code& ec, std::size_t) { + py::gil_scoped_acquire acquire; + done(!ec); + done.dec_ref(); + }); + context_.response_msg(); + } + + coro_rpc::context context_; +}; + +CoroRPCCommunicator::CoroRPCCommunicator() : impl_(std::make_shared()) {} + +CoroRPCCommunicator::~CoroRPCCommunicator() { stopServer(); } + +void CoroRPCCommunicator::setDataReceiveCallback( + std::function callback) { + LOG(INFO) << "Setting data receive callback..."; + impl_->data_receive_callback = callback; + LOG(INFO) << "Data receive callback set successfully"; +} + +bool CoroRPCCommunicator::initialize(const Config& config) { + impl_->config = config; + 
easylog::set_min_severity(easylog::Severity::WARNING); // Set log level + // to WARNING + + // Initialize client pools with proper configuration + coro_io::client_pool::pool_config pool_conf{}; + const char* value = std::getenv("MC_RPC_PROTOCOL"); + if (value && std::string_view(value) == "rdma") { + pool_conf.client_config.socket_config = + coro_io::ib_socket_t::config_t{}; + } + client_pools_ = + std::make_shared>( + pool_conf); + + LOG(INFO) << "create coro_rpc_client_pool with " << config.pool_size + << " threads"; + if (!config.listen_address.empty()) { + LOG(INFO) << "Initializing server on " << config.listen_address; + + impl_->server_ = std::make_unique( + config.thread_count, config.listen_address, + std::chrono::seconds(config.timeout_seconds)); + + if (value && std::string_view(value) == "rdma") { + if (impl_->server_) { + try { + impl_->server_->init_ibv(); + LOG(INFO) << "RDMA initialized successfully"; + } catch (const std::exception& e) { + LOG(ERROR) << "RDMA initialization failed: " << e.what(); + LOG(WARNING) << "Falling back to TCP mode"; + // Continue without RDMA - the server will use TCP + } catch (...) { + LOG(ERROR) + << "RDMA initialization failed with unknown error"; + LOG(WARNING) << "Falling back to TCP mode"; + // Continue without RDMA - the server will use TCP + } + } else { + LOG(ERROR) << "Server pointer is null, cannot initialize RDMA"; + LOG(WARNING) << "Falling back to TCP mode"; + } + } + + impl_->server_->register_handler< + &CoroRPCCommunicator::Impl::handleDataTransfer, + &CoroRPCCommunicator::Impl::handleTensorTransfer>(impl_.get()); + } + LOG(INFO) << "Environment variable MC_RPC_PROTOCOL is set to " + << (value ? 
value : "not set"); + if (value && std::string_view(value) == "rdma") { + LOG(INFO) << "Using RDMA transport for RPC communication"; + } else { + LOG(INFO) << "Using TCP transport for RPC communication"; + } + + LOG(INFO) << "Communicator initialized with client pool support"; + return true; +} + +bool CoroRPCCommunicator::startServerImpl(bool is_async) { + if (is_async) { + return this->startServerAsync(); + } else { + return this->startServer(); + } +} + +bool CoroRPCCommunicator::startServer() { + if (!impl_->server_ || impl_->config.listen_address.empty()) return false; + + try { + auto ec = impl_->server_->start(); + if (ec.val() == 0) { + impl_->is_server_started = true; + LOG(INFO) << "Server started on " << impl_->config.listen_address; + return true; + } else { + LOG(ERROR) << "Failed to start server: " << ec.message(); + return false; + } + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to start server: " << e.what(); + return false; + } +} + +bool CoroRPCCommunicator::startServerAsync() { + if (!impl_->server_ || impl_->config.listen_address.empty()) return false; + + try { + auto ec = impl_->server_->async_start(); + if (!ec.hasResult()) { + impl_->is_server_started = true; + LOG(INFO) << "Server started asynchronously on " + << impl_->config.listen_address; + return true; + } else { + LOG(ERROR) << "Failed to start server asynchronously"; + return false; + } + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to start server asynchronously: " << e.what(); + return false; + } +} + +void CoroRPCCommunicator::stopServer() { + if (impl_->is_server_started) { + impl_->is_server_started = false; + LOG(INFO) << "Server stopped"; + } +} + +int CoroRPCCommunicator::sendData(const std::string& target_address, + const void* data, size_t data_size) { + auto result = async_simple::coro::syncAwait( + sendDataAsync(target_address, data, data_size)); + return result.code; +} + +async_simple::coro::Lazy CoroRPCCommunicator::sendDataAsync( + const 
std::string& target_address, const void* data, size_t data_size) { + std::string_view data_view(static_cast(data), data_size); + + // For large data, use attachment to avoid copying + const size_t ATTACHMENT_THRESHOLD = 1024; // Use attachment for data > 1KB + + auto rpc_result = co_await client_pools_->send_request( + target_address, + [data_view, data_size](coro_rpc::coro_rpc_client& client) + -> async_simple::coro::Lazy { + if (data_size > ATTACHMENT_THRESHOLD) { + // Use attachment for large data - zero copy + client.set_req_attachment(data_view); + // Send empty data parameter, actual data in attachment + auto result = + co_await client + .call<&CoroRPCCommunicator::Impl::handleDataTransfer>( + std::string_view{}); + if (!result.has_value()) { + LOG(ERROR) << "RPC call failed: " << result.error().msg; + } + } else { + // Use regular parameter for small data + auto result = + co_await client + .call<&CoroRPCCommunicator::Impl::handleDataTransfer>( + data_view); + if (!result.has_value()) { + LOG(ERROR) << "RPC call failed: " << result.error().msg; + } + } + }); + + if (!rpc_result.has_value()) { + LOG(INFO) << "RPC send request failed"; + co_return result{-1, "RPC call failed"}; + } + result res; + res.code = 0; + co_return res; +} + +int CoroRPCCommunicator::sendTensor(const std::string& target_address, + const pybind11::object& tensor) { + // Convert pybind11::object to TensorInfo + TensorInfo tensor_info; + // TODO: Extract tensor information from pybind11::object + auto result = async_simple::coro::syncAwait( + sendTensorAsync(target_address, tensor_info)); + return result; +} + +async_simple::coro::Lazy CoroRPCCommunicator::sendTensorAsync( + const std::string& target_address, const TensorInfo& tensor) { + auto rpc_result = co_await client_pools_->send_request( + target_address, + [&tensor](coro_rpc::coro_rpc_client& client) + -> async_simple::coro::Lazy { + client.set_req_attachment( + std::string_view((char*)tensor.data_ptr, tensor.total_bytes)); + + 
auto result = + co_await client + .call<&CoroRPCCommunicator::Impl::handleTensorTransfer>(); + + if (!result.has_value()) { + LOG(ERROR) << "Tensor RPC call failed: " << result.error().msg; + } + }); + if (!rpc_result.has_value()) { + LOG(INFO) << "Tensor RPC send request failed"; + co_return -1; + } + co_return 0; +} + +int CoroRPCCommunicator::receiveData(const std::string& source_address, + void* buffer, size_t buffer_size, + int timeout_ms) { + auto result = async_simple::coro::syncAwait( + receiveDataAsync(source_address, timeout_ms)); + return 0; +} + +async_simple::coro::Lazy CoroRPCCommunicator::receiveDataAsync( + const std::string& source_address, int timeout_ms) { + // For attachment-based data reception, we should use a different approach + // This method is typically called from the handler when data is received + // The actual data reception is handled by the registered handlers + co_return std::string(); +} // Data reception is handled via context and attachment in handlers + +void CoroRPCCommunicator::Impl::handleDataTransfer( + coro_rpc::context context, std::string_view data) { + // Check if there's an attachment for large data + auto ctx_info = context.get_context_info(); + auto attachment = ctx_info->get_request_attachment(); + + LOG(INFO) << "Handling data transfer - Data: " << data.size() + << " bytes, Attachment: " << attachment.size() << " bytes"; + // Call the data receive callback if set + if (data_receive_callback) { + LOG(INFO) << "Calling data receive callback..."; + std::string_view source_address = + "unknown"; // Could extract from context if needed + + // Use attachment if available (for large data), otherwise use data + // parameter + if (!attachment.empty()) { + // Use attachment data directly without copying - zero copy approach + std::string_view attachment_view = attachment; + data_receive_callback(source_address, attachment_view); + } else { + // For small data, use the regular data parameter + 
data_receive_callback(source_address, data); + } + } else { + LOG(INFO) << "No data receive callback set!"; + } + + // Echo back the attachment for response (zero-copy) + if (!attachment.empty()) { + ctx_info->set_response_attachment(std::string_view("ok")); + } + + context.response_msg(); +} + +void CoroRPCCommunicator::Impl::handleTensorTransfer( + coro_rpc::context context) { + auto ctx_info = context.get_context_info(); + auto attachment = ctx_info->get_request_attachment(); + + LOG(INFO) << "Handling tensor transfer: " << attachment.size() << " bytes"; + + // Call the data receive callback if set (tensor data is received via + // attachment) + if (data_receive_callback) { + LOG(INFO) << "Calling data receive callback for tensor..."; + std::string_view source_address = + "unknown"; // Could extract from context if needed + + // Pass the attachment data to the callback + data_receive_callback(source_address, attachment); + } else { + LOG(INFO) << "No data receive callback set for tensor!"; + } + + ctx_info->set_response_attachment(attachment); + context.response_msg(); +} + +void CoroRPCCommunicator::Impl::handleDataTransferWithAttachment( + coro_rpc::context context, std::string_view data) { + py_rpc_context t{}; + t.context_ = std::move(context); + py::gil_scoped_acquire acquire; + // auto ctx_info = context.get_context_info(); + // auto attachment = ctx_info->get_request_attachment(); + auto view = + py::memoryview::from_buffer(data.data(), {data.size()}, {sizeof(char)}); + + // LOG(INFO) << "Handling data transfer with attachment - Data: " + // << data.size() << " bytes, Attachment: " << attachment.size() + // << " bytes" << std::endl; + py_callback(std::move(t), view); + // context.response_msg(); +} + +void CoroRPCCommunicator::Impl::handleTensorTransferWithAttachment( + coro_rpc::context context) { + py_rpc_context t{}; + + // Get the attachment before moving the context + auto ctx_info = context.get_context_info(); + auto attachment = 
ctx_info->get_request_attachment(); + + t.context_ = std::move(context); + py::gil_scoped_acquire acquire; + + auto view = py::memoryview::from_buffer( + attachment.data(), {attachment.size()}, {sizeof(int8_t)}); + + py_callback(std::move(t), view); + // auto ctx_info = context.get_context_info(); + // auto attachment = ctx_info->get_request_attachment(); + + // LOG(INFO) << "Handling tensor transfer with attachment: " + // << attachment.size() << " bytes" << std::endl; + + // ctx_info->set_response_attachment(attachment); + // context.response_msg(); +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/coro_rpc_connector/cororpc_interface.cpp b/mooncake-transfer-engine/src/transport/coro_rpc_connector/cororpc_interface.cpp new file mode 100644 index 000000000..9003d5113 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/coro_rpc_connector/cororpc_interface.cpp @@ -0,0 +1,501 @@ +#include "transport/coro_rpc_connector/cororpc_interface.h" +#include "transport/coro_rpc_connector/cororpc_communicator.h" +#include +#include +#include +#include +#include +#include +#include "async_simple/coro/SyncAwait.h" + +namespace mooncake { + +// Implementation class +class CoroRPCInterface::Impl { + public: + std::unique_ptr communicator; + pybind11::function data_receive_callback; + pybind11::function tensor_receive_callback; +}; + +// Constructor +CoroRPCInterface::CoroRPCInterface() : impl_(std::make_unique()) {} + +// Destructor +CoroRPCInterface::~CoroRPCInterface() = default; + +// Initialize +bool CoroRPCInterface::initialize(const std::string& local_address, + size_t thread_count, size_t timeout_seconds, + size_t pool_size) { + Config config; + config.listen_address = local_address; + config.thread_count = thread_count; + config.timeout_seconds = timeout_seconds; + config.pool_size = pool_size; + + impl_->communicator = std::make_unique(); + return impl_->communicator->initialize(config); +} + +// 
Convenience method for client initialization +bool CoroRPCInterface::initializeClient(size_t pool_size, + size_t timeout_seconds) { + return initialize("", 0, timeout_seconds, pool_size); +} + +// Convenience method for server initialization +bool CoroRPCInterface::initializeServer(const std::string& listen_address, + size_t thread_count, + size_t timeout_seconds) { + return initialize(listen_address, thread_count, timeout_seconds, 4); +} + +bool CoroRPCInterface::startServer() { + if (!impl_->communicator) return false; + return impl_->communicator->startServer(); +} + +bool CoroRPCInterface::startServerAsync() { + if (!impl_->communicator) return false; + return impl_->communicator->startServerAsync(); +} + +bool CoroRPCInterface::startServerImpl(bool is_async) { + if (!impl_->communicator) return false; + return impl_->communicator->startServerImpl(is_async); +} + +void CoroRPCInterface::stopServer() { + if (impl_->communicator) { + impl_->communicator->stopServer(); + } +} + +int CoroRPCInterface::sendData(const std::string& target_address, + pybind11::handle data) { + if (!impl_->communicator) return -1; + + pybind11::gil_scoped_acquire acquire; + + // Extract data from handle directly + std::string_view data_view; + try { + // Try to get direct string view from bytes object + pybind11::bytes data_bytes = + pybind11::reinterpret_borrow(data); + data_view = data_bytes; + } catch (...) 
{ + // Fallback: convert to bytes and then get view + pybind11::bytes data_bytes = pybind11::cast(data); + data_view = data_bytes; + } + + pybind11::gil_scoped_release release; + return impl_->communicator->sendData(target_address, data_view.data(), + data_view.size()); +} + +pybind11::object CoroRPCInterface::sendDataAsync(std::string& target_address, + pybind11::handle data, + pybind11::handle loop) { + pybind11::gil_scoped_acquire acquire; + + auto future_module = pybind11::module_::import("asyncio"); + auto future_obj = future_module.attr("Future")(); + + if (!impl_->communicator) { + future_obj.attr("set_exception")(pybind11::make_tuple( + pybind11::str("Communicator not initialized"))); + return future_obj; + } + + auto communicator = impl_->communicator.get(); + auto target_addr = std::move(target_address); + + // Extract data from handle directly without creating intermediate bytes + // object + std::string_view data_view; + std::string data_str; // Only used if we can't get direct view + + try { + // Try to get direct string view from bytes object + pybind11::bytes data_bytes = + pybind11::reinterpret_borrow(data); + data_view = data_bytes; + // Store a copy for lambda capture since string_view might not be valid + // after GIL release + data_str = std::string(data_view); + } catch (...) 
{
        // Fallback: convert to bytes and then to string
        pybind11::bytes data_bytes = pybind11::cast(data);
        data_str = data_bytes;
    }

    // Release GIL before starting coroutine
    pybind11::gil_scoped_release release;

    auto coro_lambda = [communicator, target_addr, data_str, future_obj,
                        loop]() -> async_simple::coro::Lazy {
        try {
            auto result_struct = co_await communicator->sendDataAsync(
                target_addr, data_str.data(), data_str.size());
            int result = result_struct.code;

            auto call_soon_threadsafe = [future_obj, loop, result]() {
                pybind11::gil_scoped_acquire acquire;
                if (result >= 0) {
                    future_obj.attr("set_result")(result);
                } else {
                    future_obj.attr("set_exception")(pybind11::make_tuple(
                        pybind11::str("Send data failed")));
                }
            };

            auto callback = pybind11::cpp_function(call_soon_threadsafe);
            loop.attr("call_soon_threadsafe")(callback);
        } catch (const std::exception& e) {
            auto call_soon_threadsafe = [future_obj, loop, e]() {
                pybind11::gil_scoped_acquire acquire;
                future_obj.attr("set_exception")(
                    pybind11::make_tuple(pybind11::str(
                        std::string("Send data error: ") + e.what())));
            };

            auto callback = pybind11::cpp_function(call_soon_threadsafe);
            loop.attr("call_soon_threadsafe")(callback);
        }
    };

    auto lazy = coro_lambda();
    lazy.start([](auto&& result) {
        if (result.hasError()) {
            std::cerr << "Coroutine completed with error";
        }
    });

    return future_obj;
}

// Sends a Python tensor-like object to `target_address` and blocks until the
// send has finished.
//
// Under the GIL it reads data_ptr(), numel(), element_size(), shape and
// str(dtype) from `tensor` and packs them into a TensorInfo (pointer plus
// metadata; the tensor bytes are not copied here). It then releases the GIL
// and synchronously awaits communicator->sendTensorAsync().
//
// Returns the value produced by sendTensorAsync(), or -1 when the communicator
// is not set, when the object's class name does not contain "Tensor", or when
// a std::exception is caught.
int CoroRPCInterface::sendTensor(const std::string& target_address,
                                 pybind11::handle tensor) {
    if (!impl_->communicator) return -1;

    try {
        TensorInfo tensor_info;

        {
            pybind11::gil_scoped_acquire acquire;
            pybind11::object tensor_obj =
                pybind11::reinterpret_borrow(tensor);

            // Validate tensor type
            // (duck-typed: only checks that type(obj).__name__ contains the
            // substring "Tensor")
            if (!(tensor_obj.attr("__class__")
                      .attr("__name__")
                      .cast()
                      .find("Tensor") != std::string::npos)) {
                std::cerr << "Input is not a tensor";
                return -1;
            }

            // Extract tensor properties - zero copy, just get pointers and
            // metadata
            // NOTE(review): assumes data_ptr() addresses numel * element_size
            // contiguous bytes; nothing here checks contiguity or device -
            // TODO confirm against callers.
            uintptr_t data_ptr =
                tensor_obj.attr("data_ptr")().cast();
            size_t numel = tensor_obj.attr("numel")().cast();
            size_t element_size =
                tensor_obj.attr("element_size")().cast();
            size_t tensor_size = numel * element_size;

            // Get tensor shape
            pybind11::object shape_obj = tensor_obj.attr("shape");
            pybind11::tuple shape_tuple =
                pybind11::cast(shape_obj);
            std::vector shape;
            for (size_t i = 0; i < shape_tuple.size(); i++) {
                shape.push_back(shape_tuple[i].cast());
            }

            // Get tensor dtype string
            pybind11::object dtype_obj = tensor_obj.attr("dtype");
            std::string dtype = dtype_obj.attr("__str__")().cast();

            // Fill TensorInfo structure (no data copying, just metadata)
            tensor_info.data_ptr = reinterpret_cast(data_ptr);
            tensor_info.total_bytes = tensor_size;
            tensor_info.shape = std::move(shape);
            tensor_info.dtype = std::move(dtype);

            // Debug trace of what is about to be sent (written to stdout
            // without a trailing newline).
            std::cout << "Sending tensor with shape: [";
            for (size_t i = 0; i < tensor_info.shape.size(); i++) {
                std::cout << tensor_info.shape[i];
                if (i < tensor_info.shape.size() - 1) std::cout << ", ";
            }
            std::cout << "] and dtype: " << tensor_info.dtype
                      << ", tensor size: " << tensor_info.total_bytes
                      << " bytes";
        }

        // Use the async version which supports zero-copy via attachments
        // The GIL is released for the duration of the blocking wait.
        pybind11::gil_scoped_release release;
        auto result = async_simple::coro::syncAwait(
            impl_->communicator->sendTensorAsync(target_address, tensor_info));
        return result;

    } catch (const std::exception& e) {
        std::cerr << "Send tensor error: " << e.what();
        return -1;
    }
}

// Asynchronous variant of sendTensor(): returns an asyncio.Future right away
// and completes it from a callback scheduled with loop.call_soon_threadsafe().
//
// `target_address` is taken by non-const reference and moved from, so the
// caller's string is left in a moved-from state after this call.
// When the communicator is not set, set_exception() has already been called
// on the returned future.
// Unlike sendTensor(), no "Tensor" class-name check is performed here.
pybind11::object CoroRPCInterface::sendTensorAsync(std::string& target_address,
                                                   pybind11::handle tensor,
                                                   pybind11::handle loop) {
    pybind11::gil_scoped_acquire acquire;

    // NOTE(review): the future is built with asyncio.Future() rather than from
    // `loop` - confirm it ends up bound to the same loop that later runs the
    // call_soon_threadsafe() callback.
    auto future_module = pybind11::module_::import("asyncio");
    auto future_obj = future_module.attr("Future")();

    if (!impl_->communicator) {
        // NOTE(review): set_exception() is handed a 1-tuple holding a str, not
        // an exception object - verify asyncio accepts this (the same pattern
        // is used in the callbacks below).
        future_obj.attr("set_exception")(pybind11::make_tuple(
            pybind11::str("Communicator not initialized")));
        return future_obj;
    }

    auto communicator = impl_->communicator.get();
    std::string target_addr = std::move(target_address);

    // Extract tensor info
    pybind11::object tensor_obj =
        pybind11::reinterpret_borrow(tensor);
    uintptr_t data_ptr = tensor_obj.attr("data_ptr")().cast();
    size_t numel = tensor_obj.attr("numel")().cast();
    size_t element_size = tensor_obj.attr("element_size")().cast();
    size_t tensor_size = numel * element_size;

    // Get tensor shape and dtype
    pybind11::object shape_obj = tensor_obj.attr("shape");
    pybind11::tuple shape_tuple = pybind11::cast(shape_obj);
    std::vector shape;
    for (size_t i = 0; i < shape_tuple.size(); i++) {
        shape.push_back(shape_tuple[i].cast());
    }

    pybind11::object dtype_obj = tensor_obj.attr("dtype");
    std::string dtype = dtype_obj.attr("__str__")().cast();

    // Create TensorInfo on stack (no need for shared_ptr)
    // NOTE(review): tensor_info keeps only the raw data pointer; nothing in
    // this function holds a reference to the Python tensor while the send is
    // in flight - presumably the caller must keep it alive; confirm.
    TensorInfo tensor_info;
    tensor_info.data_ptr = reinterpret_cast(data_ptr);
    tensor_info.total_bytes = tensor_size;
    tensor_info.shape = std::move(shape);
    tensor_info.dtype = std::move(dtype);

    pybind11::gil_scoped_release release;

    // Schedule coroutine to run asynchronously
    // NOTE(review): this is a capturing lambda that is itself a coroutine, and
    // the closure object `coro_lambda` is a local that is destroyed when this
    // function returns. If the coroutine is still suspended at that point, its
    // captures may dangle (see C++ Core Guidelines CP.51) - confirm.
    auto coro_lambda = [communicator, target_addr, tensor_info, future_obj,
                        loop]() -> async_simple::coro::Lazy {
        try {
            auto result = co_await communicator->sendTensorAsync(target_addr,
                                                                 tensor_info);

            // Runs on the event loop thread; the GIL is taken inside.
            auto call_soon_threadsafe = [future_obj, loop, result]() {
                pybind11::gil_scoped_acquire acquire;
                if (result >= 0) {
                    future_obj.attr("set_result")(result);
                } else {
                    future_obj.attr("set_exception")(pybind11::make_tuple(
                        pybind11::str("Send tensor failed")));
                }
            };

            // NOTE(review): the cpp_function is created, the pybind11 objects
            // are copied into the closure above, and loop.call_soon_threadsafe
            // is invoked here in the coroutine body, where no
            // gil_scoped_acquire is visible - verify the GIL is held.
            auto callback = pybind11::cpp_function(call_soon_threadsafe);
            loop.attr("call_soon_threadsafe")(callback);
        } catch (const std::exception& e) {
            // NOTE(review): `e` is captured by value, which copies it as a
            // plain std::exception (slicing); e.what() inside the callback may
            // therefore lose the original message - consider capturing the
            // message string instead.
            auto call_soon_threadsafe = [future_obj, loop, e]() {
                pybind11::gil_scoped_acquire acquire;
                future_obj.attr("set_exception")(
                    pybind11::make_tuple(pybind11::str(
                        std::string("Send tensor error: ") + e.what())));
            };

            auto callback = pybind11::cpp_function(call_soon_threadsafe);
            loop.attr("call_soon_threadsafe")(callback);
        }
    };

    // Start the coroutine in a detached manner (fire and forget)
    auto lazy = coro_lambda();
    lazy.start([](auto&& result) {
        // This callback will be called when the coroutine completes
        // We don't need to do anything here since the result is handled in the
        // coroutine itself
        if (result.hasError()) {
            // Log error if needed
            std::cerr << "Tensor coroutine completed with error";
        }
    });

    return future_obj;
}

// Stores the Python callback for plain data and, if a communicator exists,
// registers a C++ trampoline on it that forwards every incoming payload to
// handleIncomingData().
//
// The trampoline captures the raw `this` pointer, so this object has to
// outlive the registration. If no communicator exists yet, only the Python
// callback is stored and nothing is registered.
void CoroRPCInterface::setDataReceiveCallback(pybind11::function callback) {
    pybind11::gil_scoped_acquire acquire;
    impl_->data_receive_callback = callback;
    if (impl_->communicator) {
        auto interface_ptr = this;
        impl_->communicator->setDataReceiveCallback(
            [interface_ptr](std::string_view source, std::string_view data) {
                interface_ptr->handleIncomingData(source, data);
            });
    }
}

// Stores the Python callback for tensors.
//
// NOTE(review): unlike setDataReceiveCallback(), this does not register
// anything on the communicator, so tensors are only delivered if
// setDataReceiveCallback() has also been called - confirm this is intended.
void CoroRPCInterface::setTensorReceiveCallback(pybind11::function callback) {
    pybind11::gil_scoped_acquire acquire;
    impl_->tensor_receive_callback = callback;

    // Note: Tensor data is received through the regular data callback
    // The handleIncomingData function will detect tensor data and route it
    // to handleIncomingTensor automatically
}

// Entry point for payloads delivered through the trampoline registered in
// setDataReceiveCallback().
//
// Payloads of at least 72 bytes are probed as tensors: the first two uint32
// values are read as a dtype id and ndim, and when 1 <= dtype <= 9 and
// ndim <= 4 the payload is forwarded to handleIncomingTensor() with the shape
// read from the int64 values that start at byte offset 8. Anything else is
// passed to the Python data callback as a dict with keys "source" and "data";
// if no data callback is set, the payload is dropped.
void CoroRPCInterface::handleIncomingData(std::string_view source,
                                          std::string_view data) {
    std::cout << "CoroRPCInterface::handleIncomingData called with "
              << data.size() << " bytes";

    // C++ tensor rebuilding
    if (data.size() >= 72) {  // 72 bytes is our metadata size
        // Read the first few bytes to check if it looks like tensor metadata
        // NOTE(review): the buffer is reinterpreted in place as uint32_t /
        // int64_t; this presumes data.data() is suitably aligned - confirm.
        const uint32_t* header = reinterpret_cast(data.data());
        uint32_t dtype = header[0];
        uint32_t ndim = header[1];

        std::cout << "Checking tensor metadata: dtype=" << dtype
                  << ", ndim=" << ndim;

        // Basic validation: check if dtype and ndim are in reasonable ranges
        // NOTE(review): this is a heuristic - an ordinary data payload whose
        // first 8 bytes happen to satisfy the check is routed to the tensor
        // callback instead of the data callback.
        if (dtype > 0 && dtype <= 9 && ndim <= 4) {
            std::cout
                << "Data recognized as tensor, calling handleIncomingTensor";

            // This looks like tensor data, handle it as such
            std::vector shape;
            const int64_t* shape_data =
                reinterpret_cast(data.data() + 8);
            // Non-positive dimensions are skipped, so `shape` can end up with
            // fewer than `ndim` entries.
            for (int i = 0; i < static_cast(ndim); i++) {
                if (shape_data[i] > 0) {
                    shape.push_back(static_cast(shape_data[i]));
                }
            }

            // Get dtype name based on dtype ID
            std::string_view dtype_name;
            switch (dtype) {
                case 1:
                    dtype_name = "float16";
                    break;
                case 2:
                    dtype_name = "float32";
                    break;
                case 3:
                    dtype_name = "float64";
                    break;
                case 4:
                    dtype_name = "int8";
                    break;
                case 5:
                    dtype_name = "int16";
                    break;
                case 6:
                    dtype_name = "int32";
                    break;
                case 7:
                    dtype_name = "int64";
                    break;
                case 8:
                    dtype_name = "uint8";
                    break;
                case 9:
                    dtype_name = "bool";
                    break;
                default:
                    // Not reachable after the 1..9 range check above.
                    dtype_name = "unknown";
                    break;
            }

            // Call tensor handler instead of data handler
            // The whole payload, metadata header included, is passed on.
            handleIncomingTensor(source, data, shape, dtype_name);
            return;
        }
    }

    // Handle as regular data if not tensor data
    if (!impl_->data_receive_callback) return;

    try {
        pybind11::gil_scoped_acquire acquire;
        pybind11::dict received;
        received["source"] =
            std::string(source);  // Convert to string for Python
        received["data"] = pybind11::bytes(
            std::string(data));  // Convert to string for pybind11::bytes

        impl_->data_receive_callback(received);
    } catch (const std::exception& e) {
        LOG(ERROR) << "Error in data receive callback: " << e.what();
    }
}

// Delivers a payload that handleIncomingData() classified as a tensor to the
// Python tensor callback.
//
// Copies `source`, `data`, `shape` and `dtype` into a ReceivedTensor and calls
// the callback with it under the GIL. Returns without calling anything when no
// tensor callback is set. A std::exception thrown during the call is written
// to stderr and not propagated.
void CoroRPCInterface::handleIncomingTensor(std::string_view source,
                                            std::string_view data,
                                            const std::vector& shape,
                                            std::string_view dtype) {
    std::cout << "CoroRPCInterface::handleIncomingTensor called";
    std::cout << " source: " << source;
    std::cout << " data size: " << data.size();
    std::cout << " dtype: " << dtype;
    std::cout << " shape size: " << shape.size();

    if (!impl_->tensor_receive_callback) {
        std::cout << "No tensor receive callback set!";
        return;
    }

    std::cout << "Calling Python tensor receive callback...";

    try {
        pybind11::gil_scoped_acquire acquire;

        ReceivedTensor received;
        received.source_address = std::string(source);
        received.data = std::string(data);
        received.shape = shape;
        received.dtype = std::string(dtype);

        impl_->tensor_receive_callback(received);
    } catch (const std::exception& e) {
        std::cerr << "Error in tensor receive callback: " << e.what();
    }
}

// Factory functions for creating RPC client and server
// NOTE(review): in both factories `local_rank` and `world_size` are not used,
// and the value returned by initialize() is not checked.
std::unique_ptr createRPCClient(uint64_t local_rank,
                                uint64_t world_size) {
    auto client = std::make_unique();
    // Initialize client with default settings
    client->initialize("", 0, 30, 10);
    return client;
}

std::unique_ptr createRPCServer(uint64_t local_rank,
                                uint64_t world_size) {
    auto server = std::make_unique();
    // Initialize server with default settings
    // The listen address is fixed to "0.0.0.0:8080".
    server->initialize("0.0.0.0:8080", 0, 30, 10);
    return server;
}

}  // namespace mooncake
diff --git a/mooncake-transfer-engine/tests/communicator_bandwidth_test.py b/mooncake-transfer-engine/tests/communicator_bandwidth_test.py
new file mode 100644
index 000000000..ccb4f2198
--- /dev/null
+++ b/mooncake-transfer-engine/tests/communicator_bandwidth_test.py
@@ -0,0 +1,132 @@
# NOTE(review): torch, numpy, sys, struct and the typing names do not appear
# to be referenced anywhere in this script - confirm before removing.
import torch
import numpy as np
import time
import sys
import threading
import struct
import argparse
from typing import List, Tuple, Dict, Any

import mooncake.engine as engine

class AtomicCounter:
    """Integer counter whose reads and updates are serialized by a lock."""

    def __init__(self, initial=0):
        """Start the counter at `initial`."""
        self._value = initial
        self._lock = threading.Lock()

    def inc(self, num=1):
        """Add `num` to the counter and return the new value."""
        with self._lock:
            self._value += num
            return self._value

    def dec(self, num=1):
        """Subtract `num` from the counter and return the new value."""
        with self._lock:
            self._value -= num
            return self._value

    def get(self):
        """Return the current value and reset the counter to 0."""
        with self._lock:
            r = self._value
            self._value = 0
            return r

# Shared by the sender threads (inc) and the statistics thread (get).
counter = AtomicCounter()

# Global variable to store data size
data_size = 1024 * 1024  # Default 1MB
test_data = None  # Payload sent by every sender thread; set in run_server/run_client


def print_qps():
    """Once per second, print the bandwidth implied by the sends counted so far.

    `counter.get()` returns the number of packets counted since the previous
    call and resets the counter, so each line covers roughly one second.
    Intervals with no counted packet print nothing. Never returns; meant to
    run in a daemon thread.
    """
    while True:
        time.sleep(1)
        packets = counter.get()
        if packets == 0:
            continue
        print("bandwidth:", 8 * packets * data_size / (1000 * 1000 * 1000), "Gb/s")


def send_data(client, target_url):
    """Send `test_data` to `target_url` in an endless loop, counting successes.

    Args:
        client: object exposing ``send_data(target_url, payload)``.
        target_url: address passed through to ``client.send_data``.

    A negative integer result is treated as a failed send and is not counted,
    so failures no longer inflate the reported bandwidth. Any other result
    (including ``None``) is counted, as before. Never returns; meant to run in
    a daemon thread.
    """
    while True:
        result = client.send_data(target_url, test_data)
        # The C++ wrappers report failure with a negative code (success is
        # `result >= 0`), so only those results are excluded.
        if isinstance(result, int) and result < 0:
            continue
        counter.inc()


def run_server(bind_url, data_size_mb=1):
    """Run server mode: listen on `bind_url` until interrupted with Ctrl+C.

    Args:
        bind_url: address handed to ``initialize_server``.
        data_size_mb: packet size in MB; sets the module-level `data_size`
            and `test_data`.
    """
    global data_size, test_data
    data_size = data_size_mb * 1024 * 1024
    test_data = b'\x00' * data_size

    print(f"Starting server on {bind_url} with {data_size_mb}MB data packets")

    server = engine.CoroRPCInterface()
    server.initialize_server(bind_url, thread_count=8)
    server.start_server_async()  # start the server asynchronously

    # Start QPS statistics thread
    # NOTE(review): nothing in this script increments `counter` on the server
    # side, so this thread presumably stays silent here - confirm whether a
    # receive callback was meant to be registered.
    stats_thread = threading.Thread(target=print_qps)
    stats_thread.daemon = True
    stats_thread.start()

    print(f"Server started on {bind_url}, press Ctrl+C to stop")
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nServer stopping...")
        server.stop_server()


def run_client(target_url, num_threads=8, data_size_mb=1):
    """Run client mode: flood `target_url` from `num_threads` sender threads.

    Args:
        target_url: server address to send to.
        num_threads: number of daemon threads running `send_data`.
        data_size_mb: packet size in MB; sets the module-level `data_size`
            and `test_data`.

    Blocks until interrupted with Ctrl+C; the daemon threads end with the
    process.
    """
    global data_size, test_data
    data_size = data_size_mb * 1024 * 1024
    test_data = b'\x00' * data_size

    print(f"Starting client, connecting to {target_url} with {num_threads} threads, {data_size_mb}MB data packets")

    client = engine.CoroRPCInterface()
    client.initialize_client(pool_size=100)

    # Start QPS statistics thread
    qps_thread = threading.Thread(target=print_qps)
    qps_thread.daemon = True
    qps_thread.start()

    # Start multiple sending threads
    for _ in range(num_threads):
        sender = threading.Thread(target=send_data, args=(client, target_url))
        sender.daemon = True
        sender.start()

    print(f"Client started with {num_threads} threads, press Ctrl+C to stop")
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nClient stopping...")


def main():
    """Parse the command line and start server or client mode."""
    parser = argparse.ArgumentParser(description='Mooncake Communication Bandwidth Test Tool')
    parser.add_argument('mode', choices=['server', 'client'],
                        help='Run mode: server or client')
    parser.add_argument('--url', default='127.0.0.1:9004',
                        help='URL address (default: 127.0.0.1:9004)')
    parser.add_argument('--threads', type=int, default=8,
                        help='Number of client threads (default: 8)')
    parser.add_argument('--data-size', type=int, default=1,
                        help='Data packet size in MB (default: 1)')

    args = parser.parse_args()

    if args.mode == 'server':
        # Server mode, URL as bind address: keep only what follows the last
        # ':' (the whole value when there is no ':') and listen on all
        # interfaces.
        port = args.url.rsplit(':', 1)[-1]
        run_server(f"0.0.0.0:{port}", args.data_size)
    else:
        # Client mode, URL as target address
        run_client(args.url, args.threads, args.data_size)


if __name__ == "__main__":
    main()