@@ -503,6 +503,176 @@ class MooncakeStorePyWrapper {
503503 return results_list;
504504 }
505505
506+ int64_t get_tensor_into (const std::string &key, uintptr_t buffer_ptr,
507+ size_t size) {
508+ void *buffer = reinterpret_cast <void *>(buffer_ptr);
509+ if (!is_client_initialized ()) {
510+ LOG (ERROR) << " Client is not initialized" ;
511+ return to_py_ret (ErrorCode::INVALID_PARAMS);
512+ }
513+
514+ if (use_dummy_client_) {
515+ LOG (ERROR) << " get_tensor is not supported for dummy client now" ;
516+ return to_py_ret (ErrorCode::INVALID_PARAMS);
517+ }
518+
519+ try {
520+ // Section with GIL released
521+ py::gil_scoped_release release_gil;
522+ auto total_length = store_->get_into_internal (key, buffer, size);
523+ if (!total_length.has_value ()) {
524+ py::gil_scoped_acquire acquire_gil;
525+ return to_py_ret (ErrorCode::INVALID_PARAMS);
526+ }
527+
528+ TensorMetadata metadata;
529+ // Copy data from buffer to contiguous memory
530+ memcpy (&metadata, static_cast <char *>(buffer),
531+ sizeof (TensorMetadata));
532+
533+ if (metadata.ndim < 0 || metadata.ndim > 4 ) {
534+ py::gil_scoped_acquire acquire_gil;
535+ LOG (ERROR) << " Invalid tensor metadata: ndim=" << metadata.ndim ;
536+ return to_py_ret (ErrorCode::INVALID_PARAMS);
537+ }
538+
539+ TensorDtype dtype_enum = static_cast <TensorDtype>(metadata.dtype );
540+ if (dtype_enum == TensorDtype::UNKNOWN) {
541+ py::gil_scoped_acquire acquire_gil;
542+ LOG (ERROR) << " Unknown tensor dtype!" ;
543+ return to_py_ret (ErrorCode::INVALID_PARAMS);
544+ }
545+
546+ size_t tensor_size = total_length.value () - sizeof (TensorMetadata);
547+ if (tensor_size == 0 ) {
548+ py::gil_scoped_acquire acquire_gil;
549+ LOG (ERROR) << " Invalid data format: no tensor data found" ;
550+ return to_py_ret (ErrorCode::INVALID_PARAMS);
551+ }
552+
553+ py::gil_scoped_acquire acquire_gil;
554+ // Convert bytes to tensor using torch.from_numpy
555+ pybind11::object np_array;
556+ int dtype_index = static_cast <int >(dtype_enum);
557+ if (dtype_index < 0 ||
558+ dtype_index >= static_cast <int >(array_creators.size ())) {
559+ LOG (ERROR) << " Unsupported dtype enum: " << dtype_index;
560+ return to_py_ret (ErrorCode::INVALID_PARAMS);
561+ }
562+
563+ return total_length.value ();
564+
565+ } catch (const pybind11::error_already_set &e) {
566+ LOG (ERROR) << " Failed to get tensor data: " << e.what ();
567+ return to_py_ret (ErrorCode::INVALID_PARAMS);
568+ }
569+ }
570+
571+ pybind11::list batch_get_tensor_into (const std::vector<std::string> &keys,
572+ const std::vector<uintptr_t > &buffer_ptrs,
573+ const std::vector<size_t > &sizes) {
574+ std::vector<void *> buffers;
575+ buffers.reserve (buffer_ptrs.size ());
576+ for (uintptr_t ptr : buffer_ptrs) {
577+ buffers.push_back (reinterpret_cast <void *>(ptr));
578+ }
579+
580+ if (!is_client_initialized ()) {
581+ LOG (ERROR) << " Client is not initialized" ;
582+ py::list empty_list;
583+ for (size_t i = 0 ; i < keys.size (); ++i) {
584+ empty_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
585+ }
586+ return empty_list;
587+ }
588+
589+ if (use_dummy_client_) {
590+ LOG (ERROR) << " batch_get_tensor is not supported for dummy client "
591+ " now" ;
592+ py::list empty_list;
593+ for (size_t i = 0 ; i < keys.size (); ++i) {
594+ empty_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
595+ }
596+ return empty_list;
597+ }
598+
599+ // Phase 1: Batch Get Buffers (GIL Released)
600+ py::gil_scoped_release release_gil;
601+ // This internal call already handles logging for query failures
602+ auto total_lengths =
603+ store_->batch_get_into_internal (keys, buffers, sizes);
604+
605+ py::list results_list;
606+ try {
607+ py::gil_scoped_acquire acquire_gil;
608+ auto torch = torch_module ();
609+
610+ for (size_t i = 0 ; i < total_lengths.size (); i++) {
611+ const auto &buffer = buffers[i];
612+ if (!buffer) {
613+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
614+ continue ;
615+ }
616+
617+ auto total_length = total_lengths[i];
618+ if (!total_length.has_value ()) {
619+ LOG (ERROR) << " Invalid data format: insufficient data for"
620+ " metadata" ;
621+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
622+ continue ;
623+ }
624+ if (total_length.value () <=
625+ static_cast <long >(sizeof (TensorMetadata))) {
626+ LOG (ERROR) << " Invalid data format: insufficient data for "
627+ " metadata" ;
628+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
629+ continue ;
630+ }
631+
632+ TensorMetadata metadata;
633+ memcpy (&metadata, static_cast <char *>(buffer),
634+ sizeof (TensorMetadata));
635+
636+ if (metadata.ndim < 0 || metadata.ndim > 4 ) {
637+ LOG (ERROR)
638+ << " Invalid tensor metadata: ndim=" << metadata.ndim ;
639+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
640+ continue ;
641+ }
642+
643+ TensorDtype dtype_enum =
644+ static_cast <TensorDtype>(metadata.dtype );
645+ if (dtype_enum == TensorDtype::UNKNOWN) {
646+ LOG (ERROR) << " Unknown tensor dtype!" ;
647+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
648+ continue ;
649+ }
650+
651+ size_t tensor_size =
652+ total_length.value () - sizeof (TensorMetadata);
653+ if (tensor_size == 0 ) {
654+ LOG (ERROR) << " Invalid data format: no tensor data found" ;
655+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
656+ continue ;
657+ }
658+
659+ int dtype_index = static_cast <int >(dtype_enum);
660+ if (dtype_index < 0 ||
661+ dtype_index >= static_cast <int >(array_creators.size ())) {
662+ LOG (ERROR) << " Unsupported dtype enum: " << dtype_index;
663+ results_list.append (to_py_ret (ErrorCode::INVALID_PARAMS));
664+ continue ;
665+ }
666+
667+ results_list.append (total_length.value ());
668+ }
669+ } catch (const pybind11::error_already_set &e) {
670+ LOG (ERROR) << " Failed during batch tensor deserialization: "
671+ << e.what ();
672+ }
673+ return results_list;
674+ }
675+
506676 int put_tensor_with_tp (const std::string &key, pybind11::object tensor,
507677 int tp_rank = 0 , int tp_size = 1 ,
508678 int split_dim = 0 ) {
@@ -1241,6 +1411,15 @@ PYBIND11_MODULE(store, m) {
12411411 .def (" pub_tensor" , &MooncakeStorePyWrapper::pub_tensor, py::arg (" key" ),
12421412 py::arg (" tensor" ), py::arg (" config" ) = ReplicateConfig{},
12431413 " Publish a PyTorch tensor with configurable replication settings" )
1414+ .def (" get_tensor_into" , &MooncakeStorePyWrapper::get_tensor_into,
1415+ py::arg (" key" ), py::arg (" buffer_ptr" ), py::arg (" size" ),
1416+ " Get tensor directly into a pre-allocated buffer" )
1417+ .def (" batch_get_tensor_into" ,
1418+ &MooncakeStorePyWrapper::batch_get_tensor_into, py::arg (" keys" ),
1419+ py::arg (" buffer_ptrs" ), py::arg (" sizes" ),
1420+ " Get tensors directly into pre-allocated buffers for "
1421+ " multiple "
1422+ " keys" )
12441423 .def (
12451424 " register_buffer" ,
12461425 [](MooncakeStorePyWrapper &self, uintptr_t buffer_ptr,
0 commit comments