Skip to content

Commit 4fa75d6

Browse files
committed
[store] add get_tensor_into() and batch_get_tensor_into()
Signed-off-by: Cruz Zhao <[email protected]>
1 parent 5c3d04f commit 4fa75d6

File tree

4 files changed

+216
-6
lines changed

4 files changed

+216
-6
lines changed

mooncake-integration/store/store_py.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,176 @@ class MooncakeStorePyWrapper {
503503
return results_list;
504504
}
505505

506+
int64_t get_tensor_into(const std::string &key, uintptr_t buffer_ptr,
507+
size_t size) {
508+
void *buffer = reinterpret_cast<void *>(buffer_ptr);
509+
if (!is_client_initialized()) {
510+
LOG(ERROR) << "Client is not initialized";
511+
return to_py_ret(ErrorCode::INVALID_PARAMS);
512+
}
513+
514+
if (use_dummy_client_) {
515+
LOG(ERROR) << "get_tensor is not supported for dummy client now";
516+
return to_py_ret(ErrorCode::INVALID_PARAMS);
517+
}
518+
519+
try {
520+
// Section with GIL released
521+
py::gil_scoped_release release_gil;
522+
auto total_length = store_->get_into_internal(key, buffer, size);
523+
if (!total_length.has_value()) {
524+
py::gil_scoped_acquire acquire_gil;
525+
return to_py_ret(ErrorCode::INVALID_PARAMS);
526+
}
527+
528+
TensorMetadata metadata;
529+
// Copy data from buffer to contiguous memory
530+
memcpy(&metadata, static_cast<char *>(buffer),
531+
sizeof(TensorMetadata));
532+
533+
if (metadata.ndim < 0 || metadata.ndim > 4) {
534+
py::gil_scoped_acquire acquire_gil;
535+
LOG(ERROR) << "Invalid tensor metadata: ndim=" << metadata.ndim;
536+
return to_py_ret(ErrorCode::INVALID_PARAMS);
537+
}
538+
539+
TensorDtype dtype_enum = static_cast<TensorDtype>(metadata.dtype);
540+
if (dtype_enum == TensorDtype::UNKNOWN) {
541+
py::gil_scoped_acquire acquire_gil;
542+
LOG(ERROR) << "Unknown tensor dtype!";
543+
return to_py_ret(ErrorCode::INVALID_PARAMS);
544+
}
545+
546+
size_t tensor_size = total_length.value() - sizeof(TensorMetadata);
547+
if (tensor_size == 0) {
548+
py::gil_scoped_acquire acquire_gil;
549+
LOG(ERROR) << "Invalid data format: no tensor data found";
550+
return to_py_ret(ErrorCode::INVALID_PARAMS);
551+
}
552+
553+
py::gil_scoped_acquire acquire_gil;
554+
// Convert bytes to tensor using torch.from_numpy
555+
pybind11::object np_array;
556+
int dtype_index = static_cast<int>(dtype_enum);
557+
if (dtype_index < 0 ||
558+
dtype_index >= static_cast<int>(array_creators.size())) {
559+
LOG(ERROR) << "Unsupported dtype enum: " << dtype_index;
560+
return to_py_ret(ErrorCode::INVALID_PARAMS);
561+
}
562+
563+
return total_length.value();
564+
565+
} catch (const pybind11::error_already_set &e) {
566+
LOG(ERROR) << "Failed to get tensor data: " << e.what();
567+
return to_py_ret(ErrorCode::INVALID_PARAMS);
568+
}
569+
}
570+
571+
pybind11::list batch_get_tensor_into(const std::vector<std::string> &keys,
572+
const std::vector<uintptr_t> &buffer_ptrs,
573+
const std::vector<size_t> &sizes) {
574+
std::vector<void *> buffers;
575+
buffers.reserve(buffer_ptrs.size());
576+
for (uintptr_t ptr : buffer_ptrs) {
577+
buffers.push_back(reinterpret_cast<void *>(ptr));
578+
}
579+
580+
if (!is_client_initialized()) {
581+
LOG(ERROR) << "Client is not initialized";
582+
py::list empty_list;
583+
for (size_t i = 0; i < keys.size(); ++i) {
584+
empty_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
585+
}
586+
return empty_list;
587+
}
588+
589+
if (use_dummy_client_) {
590+
LOG(ERROR) << "batch_get_tensor is not supported for dummy client "
591+
"now";
592+
py::list empty_list;
593+
for (size_t i = 0; i < keys.size(); ++i) {
594+
empty_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
595+
}
596+
return empty_list;
597+
}
598+
599+
// Phase 1: Batch Get Buffers (GIL Released)
600+
py::gil_scoped_release release_gil;
601+
// This internal call already handles logging for query failures
602+
auto total_lengths =
603+
store_->batch_get_into_internal(keys, buffers, sizes);
604+
605+
py::list results_list;
606+
try {
607+
py::gil_scoped_acquire acquire_gil;
608+
auto torch = torch_module();
609+
610+
for (size_t i = 0; i < total_lengths.size(); i++) {
611+
const auto &buffer = buffers[i];
612+
if (!buffer) {
613+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
614+
continue;
615+
}
616+
617+
auto total_length = total_lengths[i];
618+
if (!total_length.has_value()) {
619+
LOG(ERROR) << "Invalid data format: insufficient data for"
620+
"metadata";
621+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
622+
continue;
623+
}
624+
if (total_length.value() <=
625+
static_cast<long>(sizeof(TensorMetadata))) {
626+
LOG(ERROR) << "Invalid data format: insufficient data for "
627+
"metadata";
628+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
629+
continue;
630+
}
631+
632+
TensorMetadata metadata;
633+
memcpy(&metadata, static_cast<char *>(buffer),
634+
sizeof(TensorMetadata));
635+
636+
if (metadata.ndim < 0 || metadata.ndim > 4) {
637+
LOG(ERROR)
638+
<< "Invalid tensor metadata: ndim=" << metadata.ndim;
639+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
640+
continue;
641+
}
642+
643+
TensorDtype dtype_enum =
644+
static_cast<TensorDtype>(metadata.dtype);
645+
if (dtype_enum == TensorDtype::UNKNOWN) {
646+
LOG(ERROR) << "Unknown tensor dtype!";
647+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
648+
continue;
649+
}
650+
651+
size_t tensor_size =
652+
total_length.value() - sizeof(TensorMetadata);
653+
if (tensor_size == 0) {
654+
LOG(ERROR) << "Invalid data format: no tensor data found";
655+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
656+
continue;
657+
}
658+
659+
int dtype_index = static_cast<int>(dtype_enum);
660+
if (dtype_index < 0 ||
661+
dtype_index >= static_cast<int>(array_creators.size())) {
662+
LOG(ERROR) << "Unsupported dtype enum: " << dtype_index;
663+
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS));
664+
continue;
665+
}
666+
667+
results_list.append(total_length.value());
668+
}
669+
} catch (const pybind11::error_already_set &e) {
670+
LOG(ERROR) << "Failed during batch tensor deserialization: "
671+
<< e.what();
672+
}
673+
return results_list;
674+
}
675+
506676
int put_tensor_with_tp(const std::string &key, pybind11::object tensor,
507677
int tp_rank = 0, int tp_size = 1,
508678
int split_dim = 0) {
@@ -1241,6 +1411,15 @@ PYBIND11_MODULE(store, m) {
12411411
.def("pub_tensor", &MooncakeStorePyWrapper::pub_tensor, py::arg("key"),
12421412
py::arg("tensor"), py::arg("config") = ReplicateConfig{},
12431413
"Publish a PyTorch tensor with configurable replication settings")
1414+
.def("get_tensor_into", &MooncakeStorePyWrapper::get_tensor_into,
1415+
py::arg("key"), py::arg("buffer_ptr"), py::arg("size"),
1416+
"Get tensor directly into a pre-allocated buffer")
1417+
.def("batch_get_tensor_into",
1418+
&MooncakeStorePyWrapper::batch_get_tensor_into, py::arg("keys"),
1419+
py::arg("buffer_ptrs"), py::arg("sizes"),
1420+
"Get tensors directly into pre-allocated buffers for "
1421+
"multiple "
1422+
"keys")
12441423
.def(
12451424
"register_buffer",
12461425
[](MooncakeStorePyWrapper &self, uintptr_t buffer_ptr,

mooncake-store/include/dummy_client.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,16 @@ class DummyClient : public PyClient {
7171

7272
int unregister_buffer(void *buffer);
7373

74+
tl::expected<int64_t, ErrorCode> get_into_internal(const std::string &key,
75+
void *buffer,
76+
size_t size);
77+
7478
int64_t get_into(const std::string &key, void *buffer, size_t size);
7579

80+
std::vector<tl::expected<int64_t, ErrorCode>> batch_get_into_internal(
81+
const std::vector<std::string> &keys,
82+
const std::vector<void *> &buffers, const std::vector<size_t> &sizes);
83+
7684
std::vector<int64_t> batch_get_into(const std::vector<std::string> &keys,
7785
const std::vector<void *> &buffers,
7886
const std::vector<size_t> &sizes);
@@ -221,4 +229,4 @@ class DummyClient : public PyClient {
221229
volatile bool connected_ = false;
222230
};
223231

224-
} // namespace mooncake
232+
} // namespace mooncake

mooncake-store/include/pyclient.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,17 @@ class PyClient {
5353

5454
virtual int unregister_buffer(void *buffer) = 0;
5555

56+
virtual tl::expected<int64_t, ErrorCode> get_into_internal(
57+
const std::string &key, void *buffer, size_t size) = 0;
58+
5659
virtual int64_t get_into(const std::string &key, void *buffer,
5760
size_t size) = 0;
5861

62+
virtual std::vector<tl::expected<int64_t, ErrorCode>>
63+
batch_get_into_internal(const std::vector<std::string> &keys,
64+
const std::vector<void *> &buffers,
65+
const std::vector<size_t> &sizes) = 0;
66+
5967
virtual std::vector<int64_t> batch_get_into(
6068
const std::vector<std::string> &keys,
6169
const std::vector<void *> &buffers,

mooncake-store/src/dummy_client.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,12 @@ std::vector<std::shared_ptr<BufferHandle>> DummyClient::batch_get_buffer(
511511
return std::vector<std::shared_ptr<BufferHandle>>();
512512
}
513513

514+
// Stub: direct-into-buffer get is not yet supported by the dummy client.
// Always fails with INVALID_PARAMS until implemented; callers must treat
// this path as unavailable.
tl::expected<int64_t, ErrorCode> DummyClient::get_into_internal(
    const std::string& key, void* buffer, size_t size) {
    // TODO: implement this function
    return tl::unexpected(ErrorCode::INVALID_PARAMS);
}
519+
514520
int64_t DummyClient::get_into(const std::string& key, void* buffer,
515521
size_t size) {
516522
// TODO: implement this function
@@ -548,16 +554,25 @@ int DummyClient::put_from(const std::string& key, void* buffer, size_t size,
548554
return -1;
549555
}
550556

551-
std::vector<int64_t> DummyClient::batch_get_into(
552-
const std::vector<std::string>& keys, const std::vector<void*>& buffer_ptrs,
553-
const std::vector<size_t>& sizes) {
557+
std::vector<tl::expected<int64_t, ErrorCode>>
558+
DummyClient::batch_get_into_internal(const std::vector<std::string>& keys,
559+
const std::vector<void*>& buffer_ptrs,
560+
const std::vector<size_t>& sizes) {
554561
std::vector<uint64_t> buffers;
555562
for (auto ptr : buffer_ptrs) {
556563
buffers.push_back(reinterpret_cast<uint64_t>(ptr));
557564
}
558-
auto internal_results =
565+
auto results =
559566
invoke_batch_rpc<&RealClient::batch_get_into_dummy_helper, int64_t>(
560567
keys.size(), keys, buffers, sizes, client_id_);
568+
569+
return results;
570+
}
571+
572+
std::vector<int64_t> DummyClient::batch_get_into(
573+
const std::vector<std::string>& keys, const std::vector<void*>& buffer_ptrs,
574+
const std::vector<size_t>& sizes) {
575+
auto internal_results = batch_get_into_internal(keys, buffer_ptrs, sizes);
561576
std::vector<int64_t> results;
562577
results.reserve(internal_results.size());
563578

@@ -688,4 +703,4 @@ void DummyClient::ping_thread_main() {
688703
}
689704
}
690705

691-
} // namespace mooncake
706+
} // namespace mooncake

0 commit comments

Comments
 (0)