
Commit 40d5b59

zxpdemonio authored and nickyc975 committed
[store] zero copy for get_tensor() and batch_get_tensor() (kvcache-ai#1192)
1 parent fc07213 commit 40d5b59

File tree

6 files changed: +779 −47 lines


docs/source/python-api-reference/mooncake-store.md

Lines changed: 85 additions & 0 deletions
@@ -1004,6 +1004,91 @@ def batch_get_tensor_with_tp(self, base_keys: List[str], tp_rank: int = 0, tp_si

---

### PyTorch Tensor Operations (Zero Copy)

These methods provide direct support for storing and retrieving PyTorch tensors. They automatically handle serialization and metadata, and include built-in support for **Tensor Parallelism (TP)** by automatically splitting and reconstructing tensor shards.

⚠️ **Note**: These methods require `torch` to be installed and available in the environment.

#### get_tensor_into()

Get a PyTorch tensor from the store directly into a pre-allocated buffer.

```python
def get_tensor_into(self, key: str, buffer_ptr: int, size: int) -> torch.Tensor
```

**Parameters:**

- `key` (str): Base identifier of the tensor.
- `buffer_ptr` (int): Pointer to a pre-allocated buffer that receives the tensor data; the buffer must be registered.
- `size` (int): Size of the buffer in bytes.

**Returns:**

- `torch.Tensor`: The retrieved tensor (or shard). Returns `None` if not found.
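
As a usage sketch (not part of the reference above): the snippet below assumes an already-connected store object `store`, a key previously written with `put_tensor()`, and a destination buffer registered with the store; the `register_buffer()` call and the buffer sizing are assumptions for illustration, only `get_tensor_into()` itself is documented here.

```python
import torch

# Assumptions: `store` is an initialized, connected store instance, and
# "my_tensor" was stored earlier (e.g. via store.put_tensor("my_tensor", t)).

# Pre-allocate a destination buffer large enough for the expected tensor.
buf = torch.empty(4 * 1024 * 1024, dtype=torch.uint8)  # 4 MiB scratch buffer
buffer_ptr = buf.data_ptr()
size = buf.numel() * buf.element_size()

# Assumption: zero-copy reads require the buffer to be registered first.
store.register_buffer(buffer_ptr, size)

# The tensor is reconstructed directly in the registered buffer (no extra copy).
t = store.get_tensor_into("my_tensor", buffer_ptr, size)
if t is None:
    print("key not found")
else:
    print(t.shape, t.dtype)
```
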
#### batch_get_tensor_into()

Get a batch of PyTorch tensors from the store directly into pre-allocated buffers.

```python
def batch_get_tensor_into(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int]) -> List[torch.Tensor]
```

**Parameters:**

- `base_keys` (List[str]): List of base identifiers.
- `buffer_ptrs` (List[int]): List of pointers to pre-allocated buffers that receive the tensor data; the buffers must be registered.
- `sizes` (List[int]): List of buffer sizes in bytes.

**Returns:**

- `List[torch.Tensor]`: List of retrieved tensors (or shards). Contains `None` for missing keys.
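
A batched version of the same sketch, under the same assumptions (connected `store`, buffers registered via an assumed `register_buffer()` call):

```python
import torch

# Assumption: `store` is connected and the keys below were stored earlier.
keys = ["layer0.weight", "layer1.weight"]

# One pre-allocated, registered buffer per key.
bufs = [torch.empty(4 * 1024 * 1024, dtype=torch.uint8) for _ in keys]
ptrs = [b.data_ptr() for b in bufs]
sizes = [b.numel() * b.element_size() for b in bufs]
for p, s in zip(ptrs, sizes):
    store.register_buffer(p, s)  # assumed registration API

tensors = store.batch_get_tensor_into(keys, ptrs, sizes)
for key, t in zip(keys, tensors):
    print(key, "missing" if t is None else tuple(t.shape))
```
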
#### get_tensor_with_tp_into()

Get a PyTorch tensor from the store, specifically retrieving the shard corresponding to the given Tensor Parallel rank, directly into a pre-allocated buffer.

```python
def get_tensor_with_tp_into(self, key: str, buffer_ptr: int, size: int, tp_rank: int = 0, tp_size: int = 1, split_dim: int = 0) -> torch.Tensor
```

**Parameters:**

- `key` (str): Base identifier of the tensor.
- `buffer_ptr` (int): Pointer to a pre-allocated buffer that receives the shard data; the buffer must be registered.
- `size` (int): Size of the buffer in bytes.
- `tp_rank` (int): The tensor parallel rank to retrieve (default: 0). Fetches key `key_tp_{rank}` if `tp_size > 1`.
- `tp_size` (int): Total tensor parallel size (default: 1).
- `split_dim` (int): The dimension used during splitting (default: 0).

**Returns:**

- `torch.Tensor`: The retrieved tensor (or shard). Returns `None` if not found.
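
In a TP setting each worker typically passes its own rank so it receives only its shard. A minimal sketch, reusing `buffer_ptr`/`size` from the `get_tensor_into()` example and assuming the tensor was stored split into 4 shards:

```python
# Assumptions: `buffer_ptr` and `size` describe a registered buffer (see the
# get_tensor_into() example) and "attn.qkv" was stored with tp_size=4 shards.
shard = store.get_tensor_with_tp_into(
    "attn.qkv",
    buffer_ptr,
    size,
    tp_rank=1,    # this worker's tensor-parallel rank
    tp_size=4,    # total number of shards the tensor was split into
    split_dim=0,  # must match the dimension used when splitting
)
if shard is not None:
    print("shard shape:", shard.shape)
```
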
#### batch_get_tensor_with_tp_into()

Get a batch of PyTorch tensor shards from the store for a given Tensor Parallel rank, directly into pre-allocated buffers.

```python
def batch_get_tensor_with_tp_into(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int], tp_rank: int = 0, tp_size: int = 1) -> List[torch.Tensor]
```

**Parameters:**

- `base_keys` (List[str]): List of base identifiers.
- `buffer_ptrs` (List[int]): List of pointers to pre-allocated buffers that receive the shard data; the buffers must be registered.
- `sizes` (List[int]): List of buffer sizes in bytes.
- `tp_rank` (int): The tensor parallel rank to retrieve (default: 0).
- `tp_size` (int): Total tensor parallel size (default: 1).

**Returns:**

- `List[torch.Tensor]`: List of retrieved tensors (or shards). Contains `None` for missing keys.
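
And the batched TP form, under the same assumptions (per-key registered buffers `ptrs`/`sizes` as in the `batch_get_tensor_into()` example):

```python
# Assumption: `keys`, `ptrs`, and `sizes` are prepared as in the
# batch_get_tensor_into() sketch; each entry receives this rank's shard.
shards = store.batch_get_tensor_with_tp_into(keys, ptrs, sizes, tp_rank=1, tp_size=4)
missing = [k for k, s in zip(keys, shards) if s is None]
print("missing shards:", missing)
```
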
---

### Batch Zero-Copy Operations

#### batch_put_from()

mooncake-integration/integration_utils.h

Lines changed: 10 additions & 5 deletions
@@ -35,14 +35,19 @@ enum class TensorDtype : int32_t {
 
 template <typename T>
 py::array create_typed_array(char *exported_data, size_t offset,
-                             size_t total_length) {
-    py::capsule free_when_done(
-        exported_data, [](void *p) { delete[] static_cast<char *>(p); });
+                             size_t total_length, bool take_ownership) {
+    if (take_ownership) {
+        py::capsule free_when_done(
+            exported_data, [](void *p) { delete[] static_cast<char *>(p); });
+        return py::array_t<T>({static_cast<ssize_t>(total_length / sizeof(T))},
+                              (T *)(exported_data + offset), free_when_done);
+    }
+
     return py::array_t<T>({static_cast<ssize_t>(total_length / sizeof(T))},
-                          (T *)(exported_data + offset), free_when_done);
+                          (T *)(exported_data + offset), py::none());
 }
 
-using ArrayCreatorFunc = std::function<py::array(char *, size_t, size_t)>;
+using ArrayCreatorFunc = std::function<py::array(char *, size_t, size_t, bool)>;
 
 static const std::array<ArrayCreatorFunc, 15> array_creators = {{
     create_typed_array<float>,  // FLOAT32 = 0
