Skip to content

Commit 615fb1b

Browse files
authored
Implement datasystem store to ECMooncakeConnector. (#151)
2 parents b175e40 + bd1959e commit 615fb1b

File tree

2 files changed

+189
-2
lines changed

2 files changed

+189
-2
lines changed

vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import os
34
from dataclasses import dataclass
5+
from importlib import import_module
46
from typing import TYPE_CHECKING, Optional, Union
57

68
from vllm.config import VllmConfig
79
from vllm.distributed.ec_transfer.ec_connector.base import (
810
ECConnectorBase, ECConnectorMetadata, ECConnectorRole)
9-
from vllm.distributed.ec_transfer.ec_lookup_buffer.mooncake_store import (
10-
ECMooncakeStore)
1111
from vllm.logger import init_logger
1212
from vllm.v1.core.sched.output import SchedulerOutput
1313

14+
# Maps each supported EC_STORE_TYPE value to the module that provides the
# ``ECMooncakeStore`` implementation for that backend.
_EC_STORE_MODULES = dict(
    datasystem=(
        "vllm.distributed.ec_transfer.ec_lookup_buffer.datasystem_store"),
    mooncake="vllm.distributed.ec_transfer.ec_lookup_buffer.mooncake_store",
)

# The backend is chosen once, at import time, from the EC_STORE_TYPE
# environment variable; unset or unrecognised values select Mooncake.
ec_store_type = os.environ.get("EC_STORE_TYPE", "mooncake")
if ec_store_type in _EC_STORE_MODULES:
    module_name = _EC_STORE_MODULES[ec_store_type]
else:
    module_name = _EC_STORE_MODULES["mooncake"]
ECMooncakeStore = import_module(module_name).ECMooncakeStore
24+
1425
if TYPE_CHECKING:
1526
from vllm.v1.request import Request
1627

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Datasystem ECMooncakeStore adaptor.
5+
"""
6+
7+
import json
8+
import os
9+
from queue import Queue
10+
from typing import Optional
11+
12+
import torch
13+
14+
from vllm.config import VllmConfig
15+
from vllm.distributed.parallel_state import get_world_group
16+
from vllm.logger import init_logger
17+
from vllm.utils import split_host_port
18+
19+
logger = init_logger(__name__)
20+
21+
22+
class ECMooncakeStore:
    """
    Datasystem-backed store exposing the ``ECMooncakeStore`` interface.

    The class keeps the ``ECMooncakeStore`` name so that code importing
    ``ECMooncakeStore`` from a store module can use this implementation
    without changes.

    Storage layout: each tensor is written under its key through
    ``DsTensorClient``; a small JSON document ``{"shape": ..., "dtype": ...}``
    is written under ``f"{key}_meta"`` through ``KVClient`` so that readers
    can allocate a destination tensor before fetching the payload.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        """
        Connect to the datasystem worker and prepare the put queue.

        The worker address is read from the ``DS_WORKER_ADDR`` environment
        variable (default ``127.0.0.1:31501``); the local rank of the world
        group is used as the device id.

        Args:
            vllm_config: vLLM configuration; ``ec_transfer_config`` must be
                set.

        Raises:
            ImportError: if the ``datasystem`` package is not installed.
            ValueError: if ``vllm_config.ec_transfer_config`` is ``None``.
            Exception: any client initialisation error is logged and
                re-raised unchanged.
        """
        try:
            from datasystem.ds_tensor_client import DsTensorClient
            from datasystem.kv_client import KVClient, SetParam, WriteMode
        except ImportError as e:
            raise ImportError(
                "Please install yuanrong-datasystem at "
                "https://gitee.com/openeuler/yuanrong-datasystem "
                "to run vLLM with DatasystemStore.") from e

        try:
            if vllm_config.ec_transfer_config is None:
                raise ValueError(
                    "ec_transfer_config must be set for ECConnectorBase")

            ds_worker_addr = os.getenv("DS_WORKER_ADDR", "127.0.0.1:31501")
            ip, port = split_host_port(ds_worker_addr)

            # Get local rank as device ID
            device = get_world_group().local_rank

            # Setup parameters
            self._set_param = SetParam()
            self._set_param.write_mode = WriteMode.NONE_L2_CACHE_EVICT

            # Initialize clients
            self._ds_tensor_client = DsTensorClient(ip, port, device)
            self._ds_tensor_client.init()

            self._kv_client = KVClient(ip, port)
            self._kv_client.init()

            logger.info(
                "DatasystemStore initialized successfully. "
                "DS_WORKER_ADDR: %s, IP: %s, Port: %d, Device: %d",
                ds_worker_addr, ip, port, device)

        except Exception as e:
            logger.error(
                "An error occurred while initializing DatasystemStore: %s", e)
            raise

        # Queue for handling async put futures
        self._put_queue: Queue = Queue()

    @staticmethod
    def _empty_tensor_from_meta(
            meta_bytes_data: bytes,
            device: Optional[torch.device]) -> Optional[torch.Tensor]:
        """Allocate an empty tensor described by a JSON metadata record.

        Returns ``None`` (after logging) when the record cannot be decoded,
        lacks ``shape``/``dtype``, or does not name a real ``torch.dtype``.
        """
        try:
            meta = json.loads(meta_bytes_data.decode("utf-8"))
            shape = tuple(meta["shape"])
            dtype = getattr(torch, meta["dtype"])
            # The record comes from an external store: make sure the name
            # resolved to a dtype and not to some other torch attribute.
            if not isinstance(dtype, torch.dtype):
                raise TypeError(f"{meta['dtype']!r} is not a torch dtype")
            return torch.empty(size=shape, dtype=dtype, device=device)
        except (ValueError, KeyError, AttributeError, TypeError,
                RuntimeError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error("Failed to parse metadata or create tensor: %s", e)
            return None

    def batch_get(
        self,
        keys: list[str],
        device: Optional[torch.device] = None
    ) -> list[Optional[torch.Tensor]]:
        """
        Retrieves a batch of tensors from the store.

        Args:
            keys: keys to fetch.
            device: device on which the destination tensors are allocated.

        Returns:
            A list with exactly one entry per key, in the same order. An
            entry is ``None`` when the key has no usable metadata. If the
            metadata lookup or the payload transfer fails, every entry is
            ``None`` so that callers never receive uninitialised buffers.
        """
        if not keys:
            return []

        meta_keys = [f"{key}_meta" for key in keys]
        try:
            meta_bytes_list = self._kv_client.get(meta_keys)
        except Exception as e:
            logger.error("batch_get metadata failed: %s", e)
            return [None] * len(keys)

        # Pre-allocate empty tensors based on metadata. The result is sized
        # from ``keys`` so its length is right even if the metadata reply is
        # shorter or longer than expected.
        results: list[Optional[torch.Tensor]] = [None] * len(keys)
        for idx, meta_bytes_data in zip(range(len(keys)), meta_bytes_list):
            if meta_bytes_data:
                results[idx] = self._empty_tensor_from_meta(
                    meta_bytes_data, device)

        # Only request payloads for keys that have a destination tensor;
        # ``None`` placeholders must not be handed to the tensor client.
        hit_keys = []
        hit_tensors = []
        for key, tensor in zip(keys, results):
            if tensor is not None:
                hit_keys.append(key)
                hit_tensors.append(tensor)
        if not hit_keys:
            return results

        # Fill the allocated tensors with data from the store.
        # NOTE(review): the return value of mget_h2d is ignored, as in the
        # previous implementation - confirm whether it reports per-key
        # failures that should be mapped to None here.
        try:
            self._ds_tensor_client.mget_h2d(hit_keys, hit_tensors)
        except Exception as e:
            logger.error("batch_get mget_h2d failed: %s", e)
            # The buffers may be partially filled or uninitialised, so none
            # of them can be trusted.
            return [None] * len(keys)

        return results

    def batch_put(self, keys: list[str], tensors: list[torch.Tensor]) -> None:
        """
        Stores a batch of tensors asynchronously.

        Metadata is written synchronously; the tensor payload transfer is
        only started here and must be completed with ``wait_for_put``.

        Args:
            keys: keys to store under.
            tensors: tensors to store, one per key and in the same order.

        Raises:
            ValueError: if ``keys`` and ``tensors`` differ in length.
            Exception: a metadata write error is logged and re-raised.
        """
        if not keys:
            return
        if len(keys) != len(tensors):
            raise ValueError(
                f"batch_put got {len(keys)} keys but {len(tensors)} tensors")

        meta_keys = [f"{key}_meta" for key in keys]
        meta_bytes_list = [
            json.dumps({
                "shape": list(tensor.shape),
                # str(torch.float16) == "torch.float16" -> "float16"
                "dtype": str(tensor.dtype).split(".")[-1],
            }).encode("utf-8") for tensor in tensors
        ]

        # 1. Put metadata (Sync)
        try:
            self._kv_client.mset(meta_keys, meta_bytes_list,
                                 self._set_param.write_mode)
        except Exception as e:
            logger.error("batch_put metadata mset failed: %s", e)
            raise

        # 2. Put data (Async init, returns future)
        future = self._ds_tensor_client.async_mset_d2h(keys, tensors,
                                                       self._set_param)
        self._put_queue.put(future)

    def wait_for_put(self) -> None:
        """
        Waits for all pending put operations to complete.

        Failures are logged and not raised, so one failed transfer does not
        prevent the remaining futures from being drained.
        """
        while not self._put_queue.empty():
            future = self._put_queue.get()
            try:
                # Block until transfer completes
                failed_list = future.get()
            except Exception as e:
                logger.error("Error waiting for put future: %s", e)
                continue
            if failed_list:
                logger.error("Async put transfer failed for keys: %s",
                             failed_list)

    def batch_exists(self, keys: list[str]) -> list[bool]:
        """
        Checks if keys exist in the store.

        Returns the tensor client's ``exist`` result for ``keys``, or an
        empty list when ``keys`` is empty.
        """
        if not keys:
            return []
        return self._ds_tensor_client.exist(keys)

0 commit comments

Comments
 (0)