
Commit bc1aa3b

add docs
Signed-off-by: Cruz Zhao <[email protected]>
1 parent 039a0f0 commit bc1aa3b

File tree

1 file changed: +85 -0 lines changed

docs/source/python-api-reference/mooncake-store.md

Lines changed: 85 additions & 0 deletions
@@ -1004,6 +1004,91 @@ def batch_get_tensor_with_tp(self, base_keys: List[str], tp_rank: int = 0, tp_si
---

### PyTorch Tensor Operations (Zero Copy)

These methods provide direct support for storing and retrieving PyTorch tensors. They automatically handle serialization and metadata, and include built-in support for **Tensor Parallelism (TP)** by splitting and reconstructing tensor shards.

⚠️ **Note**: These methods require `torch` to be installed and available in the environment.
#### get_tensor_into()

Get a PyTorch tensor from the store directly into a pre-allocated buffer.

```python
def get_tensor_into(self, key: str, buffer_ptr: int, size: int) -> torch.Tensor
```

**Parameters:**

- `key` (str): Identifier of the tensor.
- `buffer_ptr` (int): Pointer to the pre-allocated buffer that will receive the tensor; the buffer must be registered.
- `size` (int): Size of the buffer in bytes.

**Returns:**

- `torch.Tensor`: The retrieved tensor. Returns `None` if not found.
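For illustration, a minimal usage sketch. The `store` handle and the buffer-registration step are assumptions here; see the client setup and buffer registration sections earlier in this reference for the actual calls.

```python
import torch

def load_tensor(store, key: str, max_bytes: int) -> torch.Tensor:
    """Fetch `key` directly into a freshly allocated pinned buffer.

    Assumes `store` is an initialized Mooncake store client and that the
    buffer has been registered with the store before the call
    (registration not shown here).
    """
    # Pinned host buffer that will receive the tensor bytes.
    buffer = torch.empty(max_bytes, dtype=torch.uint8, pin_memory=True)
    buffer_ptr = buffer.data_ptr()   # raw address handed to the store
    size = buffer.numel()            # buffer size in bytes (uint8 elements)

    # The result is backed by the pre-allocated buffer; None if the key is missing.
    return store.get_tensor_into(key, buffer_ptr, size)
```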
#### batch_get_tensor()

Get a batch of PyTorch tensors from the store directly into pre-allocated buffers.

```python
def batch_get_tensor(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int]) -> List[torch.Tensor]
```

**Parameters:**

- `base_keys` (List[str]): List of base identifiers.
- `buffer_ptrs` (List[int]): Pointers to the pre-allocated buffers, one per key; the buffers must be registered.
- `sizes` (List[int]): Sizes of the buffers in bytes, one per key.

**Returns:**

- `List[torch.Tensor]`: List of retrieved tensors. Contains `None` for missing keys.
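A sketch of the batch form, under the same assumptions as above (initialized `store`, pre-registered buffers); the uniform buffer size is illustrative.

```python
import torch
from typing import List

def load_tensor_batch(store, keys: List[str], max_bytes_per_key: int) -> List[torch.Tensor]:
    """Fetch several keys at once, each into its own pre-allocated buffer.

    Assumes `store` is an initialized Mooncake store client and that every
    buffer has been registered with the store beforehand.
    """
    # One pinned buffer per key, all of the same illustrative size.
    buffers = [torch.empty(max_bytes_per_key, dtype=torch.uint8, pin_memory=True)
               for _ in keys]
    buffer_ptrs = [b.data_ptr() for b in buffers]
    sizes = [b.numel() for b in buffers]

    # Entries are None for keys that are not present in the store.
    return store.batch_get_tensor(keys, buffer_ptrs, sizes)
```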
#### get_tensor_into_with_tp()

Get a PyTorch tensor from the store, retrieving the shard corresponding to the given Tensor Parallel rank, directly into a pre-allocated buffer.

```python
def get_tensor_into_with_tp(self, key: str, buffer_ptr: int, size: int, tp_rank: int = 0, tp_size: int = 1, split_dim: int = 0) -> torch.Tensor
```

**Parameters:**

- `key` (str): Base identifier of the tensor.
- `buffer_ptr` (int): Pointer to the pre-allocated buffer that will receive the shard; the buffer must be registered.
- `size` (int): Size of the buffer in bytes.
- `tp_rank` (int): The tensor parallel rank to retrieve (default: 0). Fetches key `key_tp_{rank}` if `tp_size > 1`.
- `tp_size` (int): Total tensor parallel size (default: 1).
- `split_dim` (int): The dimension used during splitting (default: 0).

**Returns:**

- `torch.Tensor`: The retrieved tensor (or shard). Returns `None` if not found.
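A sketch of fetching only the local rank's shard, assuming the tensor was previously stored split into `tp_size` shards along dimension 0; the rank values and buffer size are illustrative.

```python
import torch

def load_tp_shard(store, key: str, tp_rank: int, tp_size: int, max_bytes: int) -> torch.Tensor:
    """Fetch the shard of `key` that belongs to `tp_rank`.

    Assumes `store` is an initialized Mooncake store client, the buffer is
    registered with the store, and the tensor was stored split along
    dimension 0 into `tp_size` shards (so `key_tp_{tp_rank}` exists when
    tp_size > 1).
    """
    buffer = torch.empty(max_bytes, dtype=torch.uint8, pin_memory=True)

    # Only this rank's shard is transferred into the local buffer.
    return store.get_tensor_into_with_tp(
        key,
        buffer.data_ptr(),
        buffer.numel(),
        tp_rank=tp_rank,
        tp_size=tp_size,
        split_dim=0,
    )  # None if the shard key is missing
```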
#### batch_get_tensor_with_tp()

Get a batch of PyTorch tensor shards from the store for a given Tensor Parallel rank, directly into pre-allocated buffers.

```python
def batch_get_tensor_with_tp(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int], tp_rank: int = 0, tp_size: int = 1) -> List[torch.Tensor]
```

**Parameters:**

- `base_keys` (List[str]): List of base identifiers.
- `buffer_ptrs` (List[int]): Pointers to the pre-allocated buffers, one per key; the buffers must be registered.
- `sizes` (List[int]): Sizes of the buffers in bytes, one per key.
- `tp_rank` (int): The tensor parallel rank to retrieve (default: 0).
- `tp_size` (int): Total tensor parallel size (default: 1).

**Returns:**

- `List[torch.Tensor]`: List of retrieved tensors (or shards). Contains `None` for missing keys.
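A sketch of fetching one rank's shards for several keys in a single call (for example, one key per layer of a checkpoint); the key names and buffer sizes are illustrative.

```python
import torch
from typing import List

def load_tp_shards(store, base_keys: List[str], tp_rank: int, tp_size: int,
                   max_bytes_per_key: int) -> List[torch.Tensor]:
    """Fetch this rank's shard for each key in `base_keys`.

    Assumes `store` is an initialized Mooncake store client and that all
    buffers have been registered with the store beforehand.
    """
    buffers = [torch.empty(max_bytes_per_key, dtype=torch.uint8, pin_memory=True)
               for _ in base_keys]

    # Entries are None for keys whose shard is not present in the store.
    return store.batch_get_tensor_with_tp(
        base_keys,
        [b.data_ptr() for b in buffers],
        [b.numel() for b in buffers],
        tp_rank=tp_rank,
        tp_size=tp_size,
    )
```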
---
### Batch Zero-Copy Operations

#### batch_put_from()
