diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp
index bdf454c68..777a804a9 100644
--- a/mooncake-integration/store/store_py.cpp
+++ b/mooncake-integration/store/store_py.cpp
@@ -315,6 +315,61 @@ int DistributedObjectStore::allocateSlices(
     return 0;
 }
 
+int DistributedObjectStore::allocateBatchedSlices(
+    const std::vector<std::string> &keys,
+    const std::vector<std::span<const char>> &values,
+    std::unordered_map<std::string, std::vector<mooncake::Slice>>
+        &batched_slices) {
+    for (size_t i = 0; i < keys.size(); ++i) {
+        uint64_t offset = 0;
+        const auto &value = values[i];
+        std::vector<mooncake::Slice> slices;
+        while (offset < value.size()) {
+            auto chunk_size = std::min(value.size() - offset, kMaxSliceSize);
+            auto ptr = client_buffer_allocator_->allocate(chunk_size);
+            if (!ptr) {
+                freeSlices(slices);  // release chunks already copied for this key
+                return 1;
+            }
+            memcpy(ptr, value.data() + offset, chunk_size);
+            slices.emplace_back(Slice{ptr, chunk_size});
+            offset += chunk_size;
+        }
+        batched_slices.emplace(keys[i], std::move(slices));
+    }
+    return 0;
+}
+
+int DistributedObjectStore::allocateBatchedSlices(
+    const std::vector<std::string> &keys,
+    std::unordered_map<std::string, std::vector<mooncake::Slice>>
+        &batched_slices,
+    const mooncake::Client::BatchObjectInfo &batched_object_info,
+    std::unordered_map<std::string, uint64_t> &str_length_map) {
+    if (batched_object_info.batch_replica_list.empty()) return -1;
+    for (const auto &key : keys) {
+        auto object_info_it = batched_object_info.batch_replica_list.find(key);
+        if (object_info_it == batched_object_info.batch_replica_list.end()) {
+            LOG(ERROR) << "Key not found: " << key;
+            return 1;
+        }
+        // Use the first replica.
+        const auto &replica = object_info_it->second[0];
+        uint64_t length = 0;
+        for (const auto &handle : replica.buffer_descriptors) {
+            auto chunk_size = handle.size_;
+            assert(chunk_size <= kMaxSliceSize);
+            auto ptr = client_buffer_allocator_->allocate(chunk_size);
+            if (!ptr) {
+                return 1;  // caller frees any slices already in batched_slices
+            }
+            batched_slices[key].emplace_back(Slice{ptr, chunk_size});
+            length += chunk_size;
+        }
+        str_length_map.emplace(key, length);
+    }
+    return 0;
+}
+
 char *DistributedObjectStore::exportSlices(
     const std::vector<mooncake::Slice> &slices, uint64_t length) {
     char *buf = new char[length + 1];
@@ -377,6 +432,39 @@ int DistributedObjectStore::put(const std::string &key,
     return 0;
 }
 
+int DistributedObjectStore::put_batch(
+    const std::vector<std::string> &keys,
+    const std::vector<std::span<const char>> &values) {
+    if (!client_) {
+        LOG(ERROR) << "Client is not initialized";
+        return 1;
+    }
+    if (keys.size() != values.size()) {
+        LOG(ERROR) << "Key and value size mismatch";
+        return 1;
+    }
+    std::unordered_map<std::string, std::vector<mooncake::Slice>>
+        batched_slices;
+    int ret = allocateBatchedSlices(keys, values, batched_slices);
+    if (ret) {
+        LOG(ERROR) << "Failed to allocate slices for put_batch operation";
+        for (auto &slice : batched_slices) {
+            freeSlices(slice.second);
+        }
+        return ret;
+    }
+
+    ReplicateConfig config;
+    config.replica_num = 1;
+    ErrorCode error_code = client_->BatchPut(keys, batched_slices, config);
+    for (auto &slice : batched_slices) {
+        freeSlices(slice.second);
+    }
+    if (error_code != ErrorCode::OK) {
+        LOG(ERROR) << "BatchPut operation failed with error: "
+                   << toString(error_code);
+        return toInt(error_code);
+    }
+    return 0;
+}
+
 int DistributedObjectStore::put_parts(
     const std::string &key, std::vector<std::span<const char>> values) {
     if (!client_) {
@@ -466,6 +554,76 @@ pybind11::bytes DistributedObjectStore::get(const std::string &key) {
     return result;
 }
 
+std::vector<pybind11::bytes> DistributedObjectStore::get_batch(
+    const std::vector<std::string> &keys) {
+    const auto kNullString = pybind11::bytes("\0", 0);
+    if (!client_) {
+        LOG(ERROR) << "Client is not initialized";
+        return {kNullString};
+    }
+    std::unordered_set<std::string> seen;
+    for (const auto &key : keys) {
+        if (!seen.insert(key).second) {
+            LOG(ERROR) << "Duplicate key not supported for Batch API, key: "
+                       << key;
+            return {kNullString};
+        }
+    }
+
+    std::vector<pybind11::bytes> results;
+    mooncake::Client::BatchObjectInfo batched_object_info;
+    std::unordered_map<std::string, std::vector<mooncake::Slice>>
+        batched_slices;
+    std::unordered_map<std::string, uint64_t> str_length_map;
+    {
+        py::gil_scoped_release release_gil;
+        ErrorCode error_code = client_->BatchQuery(keys, batched_object_info);
+        if (error_code != ErrorCode::OK) {
+            py::gil_scoped_acquire acquire_gil;
+            return {kNullString};
+        }
+        int ret = allocateBatchedSlices(keys, batched_slices,
+                                        batched_object_info, str_length_map);
+        if (ret) {
+            for (auto &slice : batched_slices) {
+                freeSlices(slice.second);
+            }
+            py::gil_scoped_acquire acquire_gil;
+            return {kNullString};
+        }
+        error_code =
+            client_->BatchGet(keys, batched_object_info, batched_slices);
+        if (error_code != ErrorCode::OK) {
+            for (auto &slice : batched_slices) {
+                freeSlices(slice.second);
+            }
+            py::gil_scoped_acquire acquire_gil;
+            return {kNullString};
+        }
+
+        // Reacquire the GIL before constructing Python bytes objects.
+        py::gil_scoped_acquire acquire_gil;
+        for (const auto &key : keys) {
+            auto &slices = batched_slices[key];
+            if (slices.size() == 1 && slices[0].size == str_length_map[key]) {
+                results.push_back(pybind11::bytes(
+                    static_cast<char *>(slices[0].ptr), str_length_map[key]));
+            } else {
+                char *exported_str_ptr =
+                    exportSlices(slices, str_length_map[key]);
+                if (!exported_str_ptr) {
+                    for (auto &slice : batched_slices) {
+                        freeSlices(slice.second);
+                    }
+                    return {kNullString};
+                }
+                results.push_back(
+                    pybind11::bytes(exported_str_ptr, str_length_map[key]));
+                delete[] exported_str_ptr;
+            }
+        }
+        for (auto &slice : batched_slices) {
+            freeSlices(slice.second);
+        }
+        if (results.size() != keys.size()) {
+            LOG(ERROR) << "Results size does not match keys size";
+            return {kNullString};
+        }
+        return results;
+    }
+}
+
 int DistributedObjectStore::remove(const std::string &key) {
     if (!client_) {
         LOG(ERROR) << "Client is not initialized";
@@ -809,6 +967,7 @@ PYBIND11_MODULE(store, m) {
         .def("setup", &DistributedObjectStore::setup)
        .def("init_all", &DistributedObjectStore::initAll)
         .def("get", &DistributedObjectStore::get)
+        .def("get_batch", &DistributedObjectStore::get_batch)
         .def("get_buffer", &DistributedObjectStore::get_buffer,
              py::call_guard<py::gil_scoped_release>(),
              py::return_value_policy::take_ownership)
@@ -867,27 +1026,49 @@ PYBIND11_MODULE(store, m) {
                     static_cast<const char *>(info.ptr),
                     static_cast<size_t>(info.size)));
         })
-        .def("put_parts", [](DistributedObjectStore &self,
-                             const std::string &key, py::args parts) {
-            // 1) Python buffer → span
-            std::vector<py::buffer_info> infos;
-            std::vector<std::span<const char>> spans;
-            infos.reserve(parts.size());
-            spans.reserve(parts.size());
-
-            for (auto &obj : parts) {
-                py::buffer buf = py::reinterpret_borrow<py::buffer>(obj);
-                infos.emplace_back(buf.request(false));
-                const auto &info = infos.back();
-                if (info.ndim != 1 || info.itemsize != 1)
-                    throw std::runtime_error("parts must be 1-D bytes-like");
-
-                spans.emplace_back(static_cast<const char *>(info.ptr),
-                                   static_cast<size_t>(info.size));
-            }
-
-            // 2) Call C++ function
-            py::gil_scoped_release unlock;
-            return self.put_parts(key, spans);
-        });
+        .def("put_parts",
+             [](DistributedObjectStore &self, const std::string &key,
+                py::args parts) {
+                 // 1) Python buffer → span
+                 std::vector<py::buffer_info> infos;
+                 std::vector<std::span<const char>> spans;
+                 infos.reserve(parts.size());
+                 spans.reserve(parts.size());
+
+                 for (auto &obj : parts) {
+                     py::buffer buf = py::reinterpret_borrow<py::buffer>(obj);
+                     infos.emplace_back(buf.request(false));
+                     const auto &info = infos.back();
+                     if (info.ndim != 1 || info.itemsize != 1)
+                         throw std::runtime_error(
+                             "parts must be 1-D bytes-like");
+
+                     spans.emplace_back(static_cast<const char *>(info.ptr),
+                                        static_cast<size_t>(info.size));
+                 }
+
+                 // 2) Call C++ function
+                 py::gil_scoped_release unlock;
+                 return self.put_parts(key, spans);
+             })
+        .def(
+            "put_batch",
+            [](DistributedObjectStore &self,
+               const std::vector<std::string> &keys,
+               const std::vector<py::bytes> &py_values) {
+                std::vector<std::string> temp_values;
+                temp_values.reserve(py_values.size());
+                for (const auto &value : py_values) {
+                    temp_values.emplace_back(value.cast<std::string>());
+                }
+
+                std::vector<std::span<const char>> spans;
+                spans.reserve(temp_values.size());
+                for (const auto &s : temp_values) {
+                    spans.emplace_back(s.data(), s.size());
+                }
+
+                return self.put_batch(keys, spans);
+            },
+            py::arg("keys"), py::arg("values"));
 }
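Reviewer note: a minimal sketch of driving the new bindings from Python. The module path (`mooncake.store`) and the `setup`/`init_all` arguments are assumptions based on the wheel layout and the existing bindings, not part of this patch.

```python
# Sketch only: assumes an already-initialized store; setup()/init_all()
# arguments are deployment-specific and elided here.
from mooncake.store import DistributedObjectStore  # assumed module path

store = DistributedObjectStore()
# store.setup(...); store.init_all(...)

keys = ["k0", "k1"]
values = [b"alpha", b"beta" * 100_000]  # large values are split into slices

assert store.put_batch(keys, values) == 0  # returns 0 on success
assert store.get_batch(keys) == values     # bytes objects, in key order
```

Note that `get_batch` rejects duplicate keys and signals any failure with a single null-byte entry, so callers should check the result length against `len(keys)`.
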
diff --git a/mooncake-integration/store/store_py.h b/mooncake-integration/store/store_py.h
index 7cf6d6fc0..d6b84a87c 100644
--- a/mooncake-integration/store/store_py.h
+++ b/mooncake-integration/store/store_py.h
@@ -141,8 +141,14 @@ class DistributedObjectStore {
     int put_parts(const std::string &key,
                   std::vector<std::span<const char>> values);
 
+    int put_batch(const std::vector<std::string> &keys,
+                  const std::vector<std::span<const char>> &values);
+
     pybind11::bytes get(const std::string &key);
 
+    std::vector<pybind11::bytes> get_batch(
+        const std::vector<std::string> &keys);
+
     /**
      * @brief Get a buffer containing the data for a key
      * @param key Key to get data for
@@ -194,6 +200,19 @@ class DistributedObjectStore {
     int allocateSlicesPacked(std::vector<mooncake::Slice> &slices,
                              const std::vector<std::span<const char>> &parts);
 
+    int allocateBatchedSlices(
+        const std::vector<std::string> &keys,
+        std::unordered_map<std::string, std::vector<mooncake::Slice>>
+            &batched_slices,
+        const mooncake::Client::BatchObjectInfo &batched_object_info,
+        std::unordered_map<std::string, uint64_t> &str_length_map);
+
+    int allocateBatchedSlices(
+        const std::vector<std::string> &keys,
+        const std::vector<std::span<const char>> &values,
+        std::unordered_map<std::string, std::vector<mooncake::Slice>>
+            &batched_slices);
+
     char *exportSlices(const std::vector<mooncake::Slice> &slices,
                        uint64_t length);
diff --git a/mooncake-wheel/mooncake/transfer_engine/__init__.py b/mooncake-wheel/mooncake/transfer_engine/__init__.py
new file mode 100644
index 000000000..ebc08cad8
--- /dev/null
+++ b/mooncake-wheel/mooncake/transfer_engine/__init__.py
@@ -0,0 +1,2 @@
+# Import for backward compatibility
+
diff --git a/mooncake-wheel/mooncake/transfer_engine/mooncake_inference.py b/mooncake-wheel/mooncake/transfer_engine/mooncake_inference.py
new file mode 100644
index 000000000..adbb18f84
--- /dev/null
+++ b/mooncake-wheel/mooncake/transfer_engine/mooncake_inference.py
@@ -0,0 +1,31 @@
+import zmq
+
+from new_tensor import create_tensor
+from transfer_engine import MooncakeTransferEngine
+
+
+class MooncakeInference:
+    def __init__(self, config: dict):
+        self.config = config
+        self._engine = MooncakeTransferEngine(
+            hostname=config["inference_ip"],
+            gpu_id=config["inference_gpu_id"],  # receive into GPU memory
+            ib_device=None,  # no specific IB device
+        )
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REQ)
+        self._socket.connect(f"tcp://{config['training_ip']}:{config['port']}")
+
+    def __del__(self):
+        self._socket.close()
+        self._context.destroy()
+
+    def recv_tensor(self) -> dict:
+        # Request the next tensor's metadata from the training side.
+        self._socket.send_json({"session_id": self._engine.get_session_id()})
+        name, shape, dtype = self._socket.recv_json()
+        tensor = create_tensor(shape, dtype, self.config["inference_gpu_id"])
+        self._engine.register(tensor.data_ptr(),
+                              tensor.numel() * tensor.element_size())
+        # Publish the destination pointer, then wait for the completion ack.
+        self._socket.send_json({"ptr": tensor.data_ptr()})
+        self._socket.recv_json()
+        self._engine.deregister(tensor.data_ptr())
+        return {name: tensor}
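Taken together with `MooncakeTraining.run` below, `recv_tensor` implements a simple per-tensor handshake; the sequence, restated from the code:

```python
# Per-tensor message sequence (receiver = REQ socket, sender = REP socket):
#   recv side: send {"session_id": ...}        # request the next tensor
#   send side: reply [name, shape, dtype]      # metadata from the send queue
#   recv side: allocate + register the buffer, send {"ptr": data_ptr}
#   send side: transfer_sync(session_id, src_ptr, dst_ptr, size), reply ack
#   recv side: deregister the buffer and return {name: tensor}
```
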
diff --git a/mooncake-wheel/mooncake/transfer_engine/mooncake_training.py b/mooncake-wheel/mooncake/transfer_engine/mooncake_training.py
new file mode 100644
index 000000000..7f13d1988
--- /dev/null
+++ b/mooncake-wheel/mooncake/transfer_engine/mooncake_training.py
@@ -0,0 +1,55 @@
+import time
+from queue import Queue
+from threading import Thread
+
+import torch
+import zmq
+
+from new_tensor import get_dtype_str
+from transfer_engine import MooncakeTransferEngine
+
+
+class MooncakeTraining(Thread):
+    def __init__(self, config: dict):
+        super().__init__()
+        self.config = config
+        self._queue = Queue(maxsize=config["send_bulk"])
+        self.engine = MooncakeTransferEngine(
+            hostname=config["training_ip"],
+            gpu_id=config["training_gpu_id"],  # send from GPU memory
+            ib_device=None,  # no specific IB device
+        )
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REP)
+        self._socket.bind(f"tcp://*:{config['port']}")
+        self._report = 0.0  # accumulated transfer time, seconds
+        self.start()
+
+    def __del__(self):
+        self._socket.close()
+        self._context.destroy()
+
+    def reg_tensor(self, name: str, tensor: torch.Tensor):
+        """Queue a tensor for sending; (None, None) stops the worker."""
+        self._queue.put((name, tensor))
+
+    def run(self):
+        while True:
+            name, tensor = self._queue.get()
+            if name is None:
+                break
+            ret = self._socket.recv_json()  # receiver requests a tensor
+            session_id = ret["session_id"]
+            self._socket.send_json(
+                [name, list(tensor.shape), get_dtype_str(tensor.dtype)])
+            ptr = tensor.data_ptr()
+            size = tensor.numel() * tensor.element_size()
+            self.engine.register(ptr, size)
+            ret = self._socket.recv_json()  # receiver's destination pointer
+            t1 = time.time()
+            self.engine.transfer_sync(session_id, ptr, ret["ptr"], size)
+            self._report += time.time() - t1
+            self._socket.send_json(ret)  # completion ack
+            self.engine.deregister(ptr)
+
+    def report(self):
+        return self._report
diff --git a/mooncake-wheel/mooncake/transfer_engine/new_tensor.py b/mooncake-wheel/mooncake/transfer_engine/new_tensor.py
new file mode 100644
index 000000000..09b15ffa2
--- /dev/null
+++ b/mooncake-wheel/mooncake/transfer_engine/new_tensor.py
@@ -0,0 +1,52 @@
+import torch
+
+# Maps the dtype strings exchanged over the wire to torch dtypes.
+# (torch.complex is a factory function, not a dtype, so it is not listed.)
+dtype_map = {
+    'bool': torch.bool,
+
+    'float16': torch.float16,
+    'float32': torch.float32,
+    'float64': torch.float64,
+    'half': torch.half,
+    'float': torch.float,
+    'double': torch.double,
+    'bfloat16': torch.bfloat16,
+
+    'int8': torch.int8,
+    'int16': torch.int16,
+    'int32': torch.int32,
+    'int64': torch.int64,
+    'uint8': torch.uint8,
+    'uint16': torch.uint16,
+    'uint32': torch.uint32,
+    'uint64': torch.uint64,
+    'long': torch.long,
+    'int': torch.int,
+    'short': torch.short,
+
+    'complex64': torch.complex64,
+    'complex128': torch.complex128,
+}
+
+
+def shape_to_stride(shape: list) -> list:
+    """Return contiguous (row-major) strides for `shape`."""
+    strides = []
+    product = 1
+    for dim in reversed(shape):
+        strides.append(product)
+        product *= dim
+    return list(reversed(strides))
+
+
+def get_child_tensor(tensor: torch.Tensor, shape: list, offset: int) -> torch.Tensor:
+    """View a contiguous sub-tensor of `tensor` starting at element `offset`."""
+    stride = shape_to_stride(shape)
+    return tensor.as_strided(shape, stride, storage_offset=offset)
+
+
+def get_dtype(dtype: str) -> torch.dtype:
+    return dtype_map[dtype]
+
+
+def get_dtype_str(dtype: torch.dtype) -> str:
+    for k, v in dtype_map.items():
+        if v == dtype:
+            return k
+    raise KeyError(f"unsupported dtype: {dtype}")
+
+
+def create_tensor(shape: list, dtype: str, gpu_id=0):
+    return torch.empty(shape, dtype=dtype_map[dtype], device=f"cuda:{gpu_id}")
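A quick CPU-only check of the helpers above (hypothetical snippet; `shape_to_stride` and `get_child_tensor` do not need a GPU):

```python
import torch
from new_tensor import shape_to_stride, get_child_tensor

# Row-major strides: strides[i] is the product of shape[i+1:].
assert shape_to_stride([2, 3, 4]) == [12, 4, 1]

# View six elements starting at storage offset 4 as a 2x3 tensor.
base = torch.arange(24)
child = get_child_tensor(base, [2, 3], 4)
assert child.tolist() == [[4, 5, 6], [7, 8, 9]]
```
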
diff --git a/mooncake-wheel/mooncake/transfer_engine/transfer_engine.py b/mooncake-wheel/mooncake/transfer_engine/transfer_engine.py
new file mode 100644
index 000000000..eb78c01b2
--- /dev/null
+++ b/mooncake-wheel/mooncake/transfer_engine/transfer_engine.py
@@ -0,0 +1,89 @@
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class MooncakeTransferEngine:
+
+    def __init__(self, hostname: str, gpu_id: int,
+                 ib_device: Optional[str] = None):
+        try:
+            from mooncake.engine import TransferEngine
+        except ImportError as e:
+            raise ImportError(
+                "Please install mooncake by following the instructions at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
+                "to run SGLang with MooncakeTransferEngine."
+            ) from e
+
+        self.engine = TransferEngine()
+        self.hostname = hostname
+        self.gpu_id = gpu_id
+        self.ib_device = ib_device
+
+        self.initialize(
+            hostname=self.hostname,
+            device_name=self.ib_device,
+        )
+        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
+
+    def register(self, ptr, length):
+        ret_value = self.engine.register_memory(ptr, length)
+        if ret_value != 0:
+            logger.error("Mooncake memory registration failed.")
+            raise RuntimeError("Mooncake memory registration failed.")
+
+    def deregister(self, ptr):
+        ret_value = self.engine.unregister_memory(ptr)
+        if ret_value != 0:
+            logger.error("Mooncake memory deregistration failed.")
+            raise RuntimeError("Mooncake memory deregistration failed.")
+
+    def initialize(
+        self,
+        hostname: str,
+        device_name: Optional[str],
+    ) -> None:
+        """Initialize the mooncake instance."""
+        ret_value = self.engine.initialize(
+            hostname,
+            "P2PHANDSHAKE",
+            "rdma",
+            device_name if device_name is not None else "",
+        )
+        if ret_value != 0:
+            logger.error("Mooncake Transfer Engine initialization failed.")
+            raise RuntimeError(
+                "Mooncake Transfer Engine initialization failed.")
+
+    def transfer_sync(
+        self, session_id: str, buffer: int, peer_buffer_address: int,
+        length: int
+    ) -> int:
+        """Synchronously write `length` bytes to the peer's address."""
+        ret = self.engine.transfer_sync_write(
+            session_id, buffer, peer_buffer_address, length
+        )
+        if ret < 0:
+            logger.error("Mooncake Transfer Engine Return Error.")
+            raise RuntimeError("Mooncake Transfer Engine Return Error.")
+        return ret
+
+    def batch_transfer_sync(
+        self, session_id: str, buffers: list[int],
+        peer_buffer_addresses: list[int], lengths: list[int]
+    ) -> int:
+        """Synchronously write a batch of buffers to the peer's addresses."""
+        ret = self.engine.batch_transfer_sync_write(
+            session_id, buffers, peer_buffer_addresses, lengths
+        )
+        if ret < 0:
+            logger.error("Mooncake Transfer Engine Return Error.")
+            raise RuntimeError("Mooncake Transfer Engine Return Error.")
+        return ret
+
+    def get_localhost(self):
+        return self.hostname
+
+    def get_session_id(self):
+        return self.session_id
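A standalone sketch of the wrapper's lifecycle. In this patch the peer's session id and destination pointer travel over the zmq side channel; the literals below are placeholders, not working values.

```python
import torch
from transfer_engine import MooncakeTransferEngine

engine = MooncakeTransferEngine(hostname="192.168.0.146", gpu_id=0)
buf = torch.ones(1 << 20, dtype=torch.uint8, device="cuda:0")

peer_session = "192.168.0.145:17000"  # receiver's get_session_id()
peer_ptr = 0x7F0000000000             # receiver's tensor.data_ptr()

engine.register(buf.data_ptr(), buf.numel())
engine.transfer_sync(peer_session, buf.data_ptr(), peer_ptr, buf.numel())
engine.deregister(buf.data_ptr())
```
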
diff --git a/mooncake-wheel/tests/test_inference.py b/mooncake-wheel/tests/test_inference.py
new file mode 100644
index 000000000..170c0817e
--- /dev/null
+++ b/mooncake-wheel/tests/test_inference.py
@@ -0,0 +1,28 @@
+import logging
+import time
+
+import yaml
+
+from load_model import get_tensor_info
+from mooncake_inference import MooncakeInference
+
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s - %(levelname)s - %(message)s',
+                    datefmt='%Y-%m-%d %H:%M:%S')
+
+
+def inference_recv(conf):
+    models = get_tensor_info(conf["model_path"])
+    total = len(models)
+
+    mck_inference = MooncakeInference(conf)
+
+    received = 0
+    while received < total:
+        mck_inference.recv_tensor()
+        received += 1
+        time.sleep(1)
+
+
+if __name__ == "__main__":
+    with open("torch.yaml") as f:
+        conf = yaml.safe_load(f)
+    logging.info(f"configuration: {conf}")
+    inference_recv(conf)
diff --git a/mooncake-wheel/tests/test_training.py b/mooncake-wheel/tests/test_training.py
new file mode 100644
index 000000000..a243021db
--- /dev/null
+++ b/mooncake-wheel/tests/test_training.py
@@ -0,0 +1,28 @@
+import logging
+
+import yaml
+
+from load_model import get_tensor_info
+from mooncake_training import MooncakeTraining
+from new_tensor import create_tensor
+
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s - %(levelname)s - %(message)s',
+                    datefmt='%Y-%m-%d %H:%M:%S')
+
+
+def training_send(conf):
+    models = get_tensor_info(conf["model_path"])
+
+    mck_train = MooncakeTraining(conf)
+
+    total_size = 0
+    for model in models:
+        name, shape, dtype, size = model[0], model[1], model[2], model[3] * 2
+        tensor = create_tensor(shape, dtype, gpu_id=conf["training_gpu_id"])
+        mck_train.reg_tensor(name, tensor)
+        total_size += size
+    mck_train.reg_tensor(None, None)  # sentinel: stop the sender thread
+    mck_train.join()  # wait for all transfers before reading the report
+
+    print("total:", total_size, "transfer seconds:", mck_train.report())
+
+
+if __name__ == "__main__":
+    with open("torch.yaml") as f:
+        conf = yaml.safe_load(f)
+    logging.info(f"Server configuration: {conf}")
+    training_send(conf)
diff --git a/mooncake-wheel/tests/torch.yaml b/mooncake-wheel/tests/torch.yaml
new file mode 100644
index 000000000..2e9f1be19
--- /dev/null
+++ b/mooncake-wheel/tests/torch.yaml
@@ -0,0 +1,9 @@
+inference_ip: "192.168.0.145"
+training_ip: "192.168.0.146"
+inference_gpu_id: 6
+training_gpu_id: 7
+port: 5555
+bulk_size: 5    # GB
+send_bulk: 15   # max tensors queued on the sender
+threads: 1
+model_path: "little.txt"
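Usage note: start test_training.py first (it binds the zmq REP socket), then test_inference.py on the other host. A sketch of the config keys the patch actually consumes (`load_model.get_tensor_info` is referenced by the tests but not included here):

```python
import yaml

with open("torch.yaml") as f:
    conf = yaml.safe_load(f)

# Keys read by the code in this patch; bulk_size and threads are
# present in torch.yaml but unused by the code shown above.
required = {"inference_ip", "training_ip", "inference_gpu_id",
            "training_gpu_id", "port", "send_bulk", "model_path"}
assert required <= conf.keys()
```
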