Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 204 additions & 23 deletions mooncake-integration/store/store_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,61 @@ int DistributedObjectStore::allocateSlices(
return 0;
}

// Split each value into chunks of at most kMaxSliceSize, each backed by a
// buffer from client_buffer_allocator_, and copy the payload into them.
//
// On allocator exhaustion, any chunks already carved off for the current key
// are moved into `batched_slices` (instead of being dropped) so the caller
// can reclaim every allocation with freeSlices().
//
// @param keys            Object keys; keys[i] corresponds to values[i].
// @param values          Payload views to stage into slices.
// @param batched_slices  Out: per-key slice lists (also holds partial
//                        allocations on failure, for caller-side cleanup).
// @return 0 on success, 1 on allocation failure.
int DistributedObjectStore::allocateBatchedSlices(
    const std::vector<std::string> &keys,
    const std::vector<std::span<const char>> &values,
    std::unordered_map<std::string, std::vector<mooncake::Slice>>
        &batched_slices) {
    for (size_t i = 0; i < keys.size(); ++i) {
        const auto &value = values[i];
        uint64_t offset = 0;
        std::vector<Slice> slices;
        while (offset < value.size()) {
            auto chunk_size = std::min(value.size() - offset, kMaxSliceSize);
            auto ptr = client_buffer_allocator_->allocate(chunk_size);
            if (!ptr) {
                // Hand partial allocations back to the caller so they are
                // not leaked: the caller frees batched_slices on error.
                batched_slices.emplace(keys[i], std::move(slices));
                return 1;
            }
            memcpy(ptr, value.data() + offset, chunk_size);
            slices.emplace_back(Slice{ptr, chunk_size});
            offset += chunk_size;
        }
        batched_slices.emplace(keys[i], std::move(slices));
    }
    return 0;
}

int DistributedObjectStore::allocateBatchedSlices(
const std::vector<std::string> &keys,
std::unordered_map<std::string, std::vector<mooncake::Slice>>
&batched_slices,
const mooncake::Client::BatchObjectInfo &batched_object_info,
std::unordered_map<std::string, uint64_t> &str_length_map) {
if (batched_object_info.batch_replica_list.empty()) return -1;
for (const auto &key : keys) {
auto object_info_it = batched_object_info.batch_replica_list.find(key);
if (object_info_it == batched_object_info.batch_replica_list.end()) {
LOG(ERROR) << "Key not found: " << key;
return 1;
}
// Get first replica
auto &replica = object_info_it->second[0];
uint64_t length = 0;
for (auto &handle : replica.buffer_descriptors) {
auto chunk_size = handle.size_;
assert(chunk_size <= kMaxSliceSize);
auto ptr = client_buffer_allocator_->allocate(chunk_size);
if (!ptr) {
return 1;
}
batched_slices[key].emplace_back(Slice{ptr, chunk_size});
length += chunk_size;
}
str_length_map.emplace(key, length);
}
return 0;
}

char *DistributedObjectStore::exportSlices(
const std::vector<mooncake::Slice> &slices, uint64_t length) {
char *buf = new char[length + 1];
Expand Down Expand Up @@ -377,6 +432,39 @@ int DistributedObjectStore::put(const std::string &key,
return 0;
}

// Store multiple objects in one BatchPut round trip.
//
// @param keys    Object keys; keys[i] pairs with values[i].
// @param values  Payload views; staged into allocator-backed slices.
// @return 0 on success, 1 on argument/allocation error, otherwise the
//         error code from BatchPut converted via toInt().
int DistributedObjectStore::put_batch(
    const std::vector<std::string> &keys,
    const std::vector<std::span<const char>> &values) {
    if (!client_) {
        LOG(ERROR) << "Client is not initialized";
        return 1;
    }
    if (keys.size() != values.size()) {
        LOG(ERROR) << "Key and value size mismatch";
        // Bug fix: previously fell through and called BatchPut with
        // mismatched key/value vectors.
        return 1;
    }
    std::unordered_map<std::string, std::vector<mooncake::Slice>>
        batched_slices;
    int ret = allocateBatchedSlices(keys, values, batched_slices);
    if (ret) {
        LOG(ERROR) << "Failed to allocate slices for put_batch operation";
        // Release whatever was allocated before the failure (leak fix).
        for (auto &slice : batched_slices) {
            freeSlices(slice.second);
        }
        return ret;
    }

    ReplicateConfig config;
    config.replica_num = 1;
    ErrorCode error_code = client_->BatchPut(keys, batched_slices, config);
    // Free the staging buffers on success and failure alike; previously
    // they leaked when BatchPut returned an error.
    for (auto &slice : batched_slices) {
        freeSlices(slice.second);
    }
    if (error_code != ErrorCode::OK) {
        LOG(ERROR) << "BatchPut operation failed with error: "
                   << toString(error_code);
        return toInt(error_code);
    }
    return 0;
}

int DistributedObjectStore::put_parts(
const std::string &key, std::vector<std::span<const char>> values) {
if (!client_) {
Expand Down Expand Up @@ -466,6 +554,76 @@ pybind11::bytes DistributedObjectStore::get(const std::string &key) {
return result;
}

std::vector<pybind11::bytes> DistributedObjectStore::get_batch(
const std::vector<std::string> &keys) {
const auto kNullString = pybind11::bytes("\0", 0);
if (!client_) {
LOG(ERROR) << "Client is not initialized";
return {kNullString};
}
std::unordered_set<std::string> seen;
for (const auto &key : keys) {
if (!seen.insert(key).second) {
LOG(ERROR) << "Duplicate key not supported for Batch API, key: "
<< key;
return {kNullString};
}
}

std::vector<pybind11::bytes> results;
mooncake::Client::BatchObjectInfo batched_object_info;
std::unordered_map<std::string, std::vector<mooncake::Slice>>
batched_slices;
std::unordered_map<std::string, uint64_t> str_length_map;
{
py::gil_scoped_release release_gil;
ErrorCode error_code = client_->BatchQuery(keys, batched_object_info);
if (error_code != ErrorCode::OK) {
py::gil_scoped_acquire acquire_gil;
return {kNullString};
} else {
int ret = allocateBatchedSlices(
keys, batched_slices, batched_object_info, str_length_map);
if (ret) {
py::gil_scoped_acquire acquire_gil;
return {kNullString};
}
error_code =
client_->BatchGet(keys, batched_object_info, batched_slices);
if (error_code != ErrorCode::OK) {
py::gil_scoped_acquire acquire_gil;
return {kNullString};
}
}
for (const auto &key : keys) {
if (batched_slices[key].size() == 1 &&
batched_slices[key][0].size == str_length_map[key]) {
results.push_back(pybind11::bytes(
static_cast<char *>(batched_slices[key][0].ptr),
str_length_map[key]));
} else {
char *exported_str_ptr =
exportSlices(batched_slices[key], str_length_map[key]);
if (!exported_str_ptr) {
return {kNullString};
} else {
results.push_back(
pybind11::bytes(exported_str_ptr, str_length_map[key]));
delete[] exported_str_ptr;
}
}
}
if (results.size() != keys.size()) {
LOG(ERROR) << "Results size does not match keys size";
return {kNullString};
}
for (auto &slice : batched_slices) {
freeSlices(slice.second);
}
return results;
}
}

int DistributedObjectStore::remove(const std::string &key) {
if (!client_) {
LOG(ERROR) << "Client is not initialized";
Expand Down Expand Up @@ -809,6 +967,7 @@ PYBIND11_MODULE(store, m) {
.def("setup", &DistributedObjectStore::setup)
.def("init_all", &DistributedObjectStore::initAll)
.def("get", &DistributedObjectStore::get)
.def("get_batch", &DistributedObjectStore::get_batch)
.def("get_buffer", &DistributedObjectStore::get_buffer,
py::call_guard<py::gil_scoped_release>(),
py::return_value_policy::take_ownership)
Expand Down Expand Up @@ -867,27 +1026,49 @@ PYBIND11_MODULE(store, m) {
static_cast<char *>(info.ptr),
static_cast<size_t>(info.size)));
})
.def("put_parts", [](DistributedObjectStore &self,
const std::string &key, py::args parts) {
// 1) Python buffer → span
std::vector<py::buffer_info> infos;
std::vector<std::span<const char>> spans;
infos.reserve(parts.size());
spans.reserve(parts.size());

for (auto &obj : parts) {
py::buffer buf = py::reinterpret_borrow<py::buffer>(obj);
infos.emplace_back(buf.request(false));
const auto &info = infos.back();
if (info.ndim != 1 || info.itemsize != 1)
throw std::runtime_error("parts must be 1-D bytes-like");

spans.emplace_back(static_cast<const char *>(info.ptr),
static_cast<size_t>(info.size));
}

// 2) Call C++ function
py::gil_scoped_release unlock;
return self.put_parts(key, spans);
});
.def("put_parts",
[](DistributedObjectStore &self, const std::string &key,
py::args parts) {
// 1) Python buffer → span
std::vector<py::buffer_info> infos;
std::vector<std::span<const char>> spans;
infos.reserve(parts.size());
spans.reserve(parts.size());

for (auto &obj : parts) {
py::buffer buf = py::reinterpret_borrow<py::buffer>(obj);
infos.emplace_back(buf.request(false));
const auto &info = infos.back();
if (info.ndim != 1 || info.itemsize != 1)
throw std::runtime_error(
"parts must be 1-D bytes-like");

spans.emplace_back(static_cast<const char *>(info.ptr),
static_cast<size_t>(info.size));
}

// 2) Call C++ function
py::gil_scoped_release unlock;
return self.put_parts(key, spans);
})
.def(
    "put_batch",
    [](DistributedObjectStore &self,
       const std::vector<std::string> &keys,
       const std::vector<py::bytes> &py_values) {
        // Copy each Python bytes into an owned std::string first: the
        // spans below must point at stable C++ memory for the duration
        // of the put_batch call.
        std::vector<std::string> temp_values;
        temp_values.reserve(py_values.size());
        for (const auto &value : py_values) {
            temp_values.emplace_back(value.cast<std::string>());
        }

        // Build non-owning views over the copies; temp_values stays
        // alive until put_batch returns.
        std::vector<std::span<const char>> spans;
        spans.reserve(temp_values.size());
        for (const auto &s : temp_values) {
            spans.emplace_back(s.data(), s.size());
        }

        return self.put_batch(keys, spans);
    },
    py::arg("keys"), py::arg("values"));
}
19 changes: 19 additions & 0 deletions mooncake-integration/store/store_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,14 @@ class DistributedObjectStore {
int put_parts(const std::string &key,
std::vector<std::span<const char>> values);

int put_batch(const std::vector<std::string> &keys,
const std::vector<std::span<const char>> &values);

pybind11::bytes get(const std::string &key);

std::vector<pybind11::bytes> get_batch(
const std::vector<std::string> &keys);

/**
* @brief Get a buffer containing the data for a key
* @param key Key to get data for
Expand Down Expand Up @@ -194,6 +200,19 @@ class DistributedObjectStore {
int allocateSlicesPacked(std::vector<mooncake::Slice> &slices,
const std::vector<std::span<const char>> &parts);

// GET-side variant: allocates receive slices sized from the replica
// buffer descriptors returned by BatchQuery, and records each object's
// total length in str_length_map.
int allocateBatchedSlices(
    const std::vector<std::string> &keys,
    std::unordered_map<std::string, std::vector<mooncake::Slice>>
        &batched_slices,
    const mooncake::Client::BatchObjectInfo &batched_object_info,
    std::unordered_map<std::string, uint64_t> &str_length_map);

// PUT-side variant: chunks each value into kMaxSliceSize slices and
// copies the payload into allocator-backed buffers.
int allocateBatchedSlices(
    const std::vector<std::string> &keys,
    const std::vector<std::span<const char>> &values,
    std::unordered_map<std::string, std::vector<mooncake::Slice>>
        &batched_slices);

char *exportSlices(const std::vector<mooncake::Slice> &slices,
uint64_t length);

Expand Down
2 changes: 2 additions & 0 deletions mooncake-wheel/mooncake/transfer_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Import for backward compatibility

31 changes: 31 additions & 0 deletions mooncake-wheel/mooncake/transfer_engine/mooncake_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import zmq

from transfer_engine import MooncakeTransferEngine
Copy link

Copilot AI Jun 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a relative import for the local transfer_engine module to avoid import errors (e.g., from .transfer_engine import MooncakeTransferEngine).

Suggested change
from transfer_engine import MooncakeTransferEngine
from .transfer_engine import MooncakeTransferEngine

Copilot uses AI. Check for mistakes.
from new_tensor import create_tensor

class MooncakeInference:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these class names do not precisely reveal what they are actually doing. Maybe the usage is for RL training and inference, but the names are very confusing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Will fix

    def __init__(self, config: dict):
        """Set up the inference-side transfer engine and a ZeroMQ REQ socket
        connected to the training side.

        Expected config keys (from usage below): ``inference_ip``,
        ``inference_gpu_id``, ``training_ip``, ``port``.
        """
        self.config = config
        self._engine = MooncakeTransferEngine(
            hostname = config["inference_ip"],
            gpu_id = config["inference_gpu_id"], # Using GPU
            ib_device = None # No specific IB device
        )
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        # REQ pairs with the REP socket bound by the training side.
        self._socket.connect(f"tcp://{config['training_ip']}:{config['port']}")

    def __del__(self):
        # Best-effort teardown of the ZeroMQ socket and context on GC.
        self._socket.close()
        self._context.destroy()

    def recv_tensor(self) -> dict:
        """Request one tensor from the training side and receive it via the
        transfer engine.

        REQ-side protocol (must mirror MooncakeTraining.run): send our
        session id; receive ``[name, shape, dtype]``; allocate a matching
        tensor; send its device pointer; wait for the reply that signals the
        transfer finished; then deregister the memory.

        Returns:
            dict: ``{tensor_name: tensor}`` containing the received tensor.
        """
        self._socket.send_json({"session_id": self._engine.get_session_id()})
        ret = self._socket.recv_json() # ["name", "shape", "dtype"]
        tensor = create_tensor(ret[1], ret[2], self.config["inference_gpu_id"])
        # Register the tensor's memory so the peer can transfer into it.
        self._engine.register(tensor.data_ptr(), tensor.numel() * tensor.element_size())
        self._socket.send_json({"ptr": tensor.data_ptr()})
        # Reply from the peer indicates the transfer has completed.
        self._socket.recv_json()
        self._engine.deregister(tensor.data_ptr())
        return {ret[0]: tensor}
55 changes: 55 additions & 0 deletions mooncake-wheel/mooncake/transfer_engine/mooncake_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import zmq
import time
from threading import Thread
from queue import Queue
from new_tensor import get_dtype_str

from transfer_engine import MooncakeTransferEngine
Copy link

Copilot AI Jun 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a relative import for the local transfer_engine module to avoid import errors (e.g., from .transfer_engine import MooncakeTransferEngine).

Suggested change
from transfer_engine import MooncakeTransferEngine
from .transfer_engine import MooncakeTransferEngine

Copilot uses AI. Check for mistakes.

class MooncakeTraining(Thread):
    """Worker thread on the training side that serves queued tensors to a
    peer over a ZeroMQ REP socket plus Mooncake transfers.

    Tensors are enqueued with :meth:`reg_tensor`; :meth:`run` answers one
    peer request per queued tensor. Enqueueing ``(None, ...)`` stops the
    loop (the REQ/REP message order must mirror
    ``MooncakeInference.recv_tensor``).
    """

    def __init__(self, config: dict):
        super().__init__()
        # NOTE(review): never read anywhere; shutdown is done via the
        # None sentinel in the queue instead.
        self._stop = False
        self.config = config
        self._queue = Queue(maxsize=config["send_bulk"])
        self.engine = MooncakeTransferEngine(
            hostname = config["training_ip"],
            gpu_id = config["training_gpu_id"], # Using GPU
            ib_device = None # No specific IB device
        )
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        self._socket.bind(f"tcp://*:{config['port']}")
        # Cumulative seconds spent inside transfer_sync (see report()).
        self._report = 0
        # Starts the run() loop immediately on construction.
        self.start()

    def __del__(self):
        # Best-effort teardown; does not join the worker thread.
        self._socket.close()
        self._context.destroy()


    def reg_tensor(self, name: str, tensor: torch.Tensor):
        # Enqueue a tensor to serve; blocks when the queue is full.
        self._queue.put((name, tensor))

    def run(self):
        """Serve queued tensors until a ``(None, ...)`` sentinel arrives."""
        while True:
            name, tensor = self._queue.get()
            if name is None:
                # Sentinel: shut down the worker loop.
                break
            ret = self._socket.recv_json() # get req.
            session_id = ret["session_id"]
            # Announce the tensor's metadata so the peer can allocate.
            self._socket.send_json([name, tensor.shape, get_dtype_str(tensor.dtype)])
            ptr, size = tensor.data_ptr(), tensor.numel() * tensor.element_size()
            self.engine.register(ptr, size)
            # Peer replies with the destination device pointer.
            ret = self._socket.recv_json()
            t1 = time.time()
            self.engine.transfer_sync(session_id, ptr, ret["ptr"], size)
            t2 = time.time()
            self._report += t2 - t1
            # Final reply acts as the "transfer done" acknowledgement.
            self._socket.send_json(ret)
            self.engine.deregister(ptr)

    def report(self):
        # Total seconds spent in transfer_sync so far.
        return self._report

Loading
Loading