Commit 65a840e

Author: wangzaijun
Commit message: fix
1 parent 8c3c44c commit 65a840e

File tree: 21 files changed, +461 / -288 lines


lightllm/common/deepseek2_mem_manager.py

Lines changed: 18 additions & 15 deletions
@@ -42,13 +42,15 @@ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
             (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
         return self.kv_move_buffer
-
-    def write_mem_to_page_kv_move_buffer(self,
-                                         mem_indexes: List[int],
-                                         page_index: int,
-                                         dp_index: int,
-                                         mem_managers: List["MemoryManager"],
-                                         dp_world_size:int):
+
+    def write_mem_to_page_kv_move_buffer(
+        self,
+        mem_indexes: List[int],
+        page_index: int,
+        dp_index: int,
+        mem_managers: List["MemoryManager"],
+        dp_world_size: int,
+    ):
         cur_page = self.kv_move_buffer[page_index]
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         mla_page_io(
@@ -58,13 +60,15 @@ def write_mem_to_page_kv_move_buffer(self,
             mode="write",
         )
         return
-
-    def read_page_kv_move_buffer_to_mem(self,
-                                        mem_indexes: List[int],
-                                        page_index: int,
-                                        dp_index: int,
-                                        mem_managers: List["MemoryManager"],
-                                        dp_world_size:int):
+
+    def read_page_kv_move_buffer_to_mem(
+        self,
+        mem_indexes: List[int],
+        page_index: int,
+        dp_index: int,
+        mem_managers: List["MemoryManager"],
+        dp_world_size: int,
+    ):
         cur_page = self.kv_move_buffer[page_index]
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         mem_indexes = torch.tensor(mem_indexes, dtype=torch.int64, device="cuda")
@@ -76,7 +80,6 @@ def read_page_kv_move_buffer_to_mem(self,
             mode="read",
         )
 
-
     def send_to_decode_node(
         self,
         move_tasks: List[KVMoveTask],
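
The two staging methods above now take keyword-friendly signatures. A minimal usage sketch (the index values and the `all_mem_managers` list are made up for illustration; only the signature comes from this diff):

# Hypothetical call; `all_mem_managers` (one MemoryManager per TP rank,
# grouped by DP rank) and the index values are illustrative only.
mem_mgr.write_mem_to_page_kv_move_buffer(
    mem_indexes=[3, 7, 11],   # token slots to stage into the page
    page_index=0,             # which page of kv_move_buffer to fill
    dp_index=0,               # which DP group the tokens belong to
    mem_managers=all_mem_managers,
    dp_world_size=2,
)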

lightllm/common/kv_trans_kernel/nixl_kv_trans.py

Lines changed: 120 additions & 21 deletions
@@ -6,6 +6,7 @@
 
 logger = init_logger(__name__)
 
+
 @triton.jit
 def _page_io(
     mem_index_ptr,
@@ -43,7 +44,7 @@ def _page_io(
     v_stride_layer_num = tl.cast(v_stride_layer_num, dtype=tl.int64)
     k_stride_size = tl.cast(k_stride_size, dtype=tl.int64)
     v_stride_size = tl.cast(v_stride_size, dtype=tl.int64)
-
+
     tid = tl.program_id(0)
     kv_head_id = tl.program_id(1)
     page_head_id = page_head_start + kv_head_id
@@ -57,18 +58,86 @@ def _page_io(
 
     for layer_index in tl.range(layer_num, num_stages=3):
         if IS_WRITE:
-            k_tensor = tl.load(k_ptr + layer_index * k_stride_layer_num + mem_index * k_stride_size + kv_head_id * k_stride_head + off_dim * k_stride_dim, mask=mask)
-            v_tensor = tl.load(v_ptr + layer_index * v_stride_layer_num + mem_index * v_stride_size + kv_head_id * v_stride_head + off_dim * v_stride_dim, mask=mask)
-            tl.store(k_page_ptr + tid * k_page_stride_size + layer_index * k_page_stride_layer_num + page_head_id * k_page_stride_head + off_dim * k_page_stride_dim, k_tensor, mask=mask)
-            tl.store(v_page_ptr + tid * v_page_stride_size + layer_index * v_page_stride_layer_num + page_head_id * v_page_stride_head + off_dim * v_page_stride_dim, v_tensor, mask=mask)
+            k_tensor = tl.load(
+                k_ptr
+                + layer_index * k_stride_layer_num
+                + mem_index * k_stride_size
+                + kv_head_id * k_stride_head
+                + off_dim * k_stride_dim,
+                mask=mask,
+            )
+            v_tensor = tl.load(
+                v_ptr
+                + layer_index * v_stride_layer_num
+                + mem_index * v_stride_size
+                + kv_head_id * v_stride_head
+                + off_dim * v_stride_dim,
+                mask=mask,
+            )
+            tl.store(
+                k_page_ptr
+                + tid * k_page_stride_size
+                + layer_index * k_page_stride_layer_num
+                + page_head_id * k_page_stride_head
+                + off_dim * k_page_stride_dim,
+                k_tensor,
+                mask=mask,
+            )
+            tl.store(
+                v_page_ptr
+                + tid * v_page_stride_size
+                + layer_index * v_page_stride_layer_num
+                + page_head_id * v_page_stride_head
+                + off_dim * v_page_stride_dim,
+                v_tensor,
+                mask=mask,
+            )
         else:
-            k_page_tensor = tl.load(k_page_ptr + tid * k_page_stride_size + layer_index * k_page_stride_layer_num + page_head_id * k_page_stride_head + off_dim * k_page_stride_dim, mask=mask)
-            v_page_tensor = tl.load(v_page_ptr + tid * v_page_stride_size + layer_index * v_page_stride_layer_num + page_head_id * v_page_stride_head + off_dim * v_page_stride_dim, mask=mask)
-            tl.store(k_ptr + layer_index * k_stride_layer_num + mem_index * k_stride_size + kv_head_id * k_stride_head + off_dim * k_stride_dim, k_page_tensor, mask=mask)
-            tl.store(v_ptr + layer_index * v_stride_layer_num + mem_index * v_stride_size + kv_head_id * v_stride_head + off_dim * v_stride_dim, v_page_tensor, mask=mask)
+            k_page_tensor = tl.load(
+                k_page_ptr
+                + tid * k_page_stride_size
+                + layer_index * k_page_stride_layer_num
+                + page_head_id * k_page_stride_head
+                + off_dim * k_page_stride_dim,
+                mask=mask,
+            )
+            v_page_tensor = tl.load(
+                v_page_ptr
+                + tid * v_page_stride_size
+                + layer_index * v_page_stride_layer_num
+                + page_head_id * v_page_stride_head
+                + off_dim * v_page_stride_dim,
+                mask=mask,
+            )
+            tl.store(
+                k_ptr
+                + layer_index * k_stride_layer_num
+                + mem_index * k_stride_size
+                + kv_head_id * k_stride_head
+                + off_dim * k_stride_dim,
+                k_page_tensor,
+                mask=mask,
+            )
+            tl.store(
+                v_ptr
+                + layer_index * v_stride_layer_num
+                + mem_index * v_stride_size
+                + kv_head_id * v_stride_head
+                + off_dim * v_stride_dim,
+                v_page_tensor,
+                mask=mask,
+            )
     return
 
-def page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torch.Tensor, tp_index:int, tp_world_size:int, mode:str):
+
+def page_io(
+    mem_indexes: torch.Tensor,
+    page_tensor: torch.Tensor,
+    kv_buffer: torch.Tensor,
+    tp_index: int,
+    tp_world_size: int,
+    mode: str,
+):
     assert mode in ["read", "write"]
     assert mem_indexes.is_contiguous()
     assert page_tensor.is_contiguous()
@@ -86,9 +155,10 @@ def page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torc
     v_page_tensor = page_tensor[:, :, -page_v_head_num:, :]
 
     k_head_num, v_head_num = kv_head_num // 2, kv_head_num // 2
+    assert k_head_num == v_head_num
     k_buffer = kv_buffer[:, :, 0:k_head_num, :]
     v_buffer = kv_buffer[:, :, k_head_num:, :]
-
+
     tp_index = tp_index // repeat_count
     tp_world_size = tp_world_size // repeat_count
 
@@ -127,14 +197,13 @@ def page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torc
         layer_num=layer_num,
         head_dim=head_dim,
         HEAD_DIM_BLOCK=triton.next_power_of_2(head_dim),
-        IS_WRITE=mode=="write",
+        IS_WRITE=mode == "write",
         NEED_MASK=triton.next_power_of_2(head_dim) != head_dim,
         num_warps=1,
     )
     return
 
 
-
 @triton.jit
 def _mla_page_io(
     mem_index_ptr,
@@ -157,7 +226,7 @@ def _mla_page_io(
     page_stride_size = tl.cast(page_stride_size, dtype=tl.int64)
     kv_stride_layer_num = tl.cast(kv_stride_layer_num, dtype=tl.int64)
     kv_stride_size = tl.cast(kv_stride_size, dtype=tl.int64)
-
+
     tid = tl.program_id(0)
 
     mem_index = tl.load(mem_index_ptr + tid)
@@ -169,14 +238,45 @@ def _mla_page_io(
 
     for layer_index in tl.range(layer_num, num_stages=3):
         if IS_WRITE:
-            kv_tensor = tl.load(kv_ptr + layer_index * kv_stride_layer_num + mem_index * kv_stride_size + 0 * kv_stride_head + off_dim * kv_stride_dim, mask=mask)
-            tl.store(page_ptr + tid * page_stride_size + layer_index * page_stride_layer_num + 0 * page_stride_head + off_dim * page_stride_dim, kv_tensor, mask=mask)
+            kv_tensor = tl.load(
+                kv_ptr
+                + layer_index * kv_stride_layer_num
+                + mem_index * kv_stride_size
+                + 0 * kv_stride_head
+                + off_dim * kv_stride_dim,
+                mask=mask,
+            )
+            tl.store(
+                page_ptr
+                + tid * page_stride_size
+                + layer_index * page_stride_layer_num
+                + 0 * page_stride_head
+                + off_dim * page_stride_dim,
+                kv_tensor,
+                mask=mask,
+            )
         else:
-            page_tensor = tl.load(page_ptr + tid * page_stride_size + layer_index * page_stride_layer_num + 0 * page_stride_head + off_dim * page_stride_dim, mask=mask)
-            tl.store(kv_ptr + layer_index * kv_stride_layer_num + mem_index * kv_stride_size + 0 * kv_stride_head + off_dim * kv_stride_dim, page_tensor, mask=mask)
+            page_tensor = tl.load(
+                page_ptr
+                + tid * page_stride_size
+                + layer_index * page_stride_layer_num
+                + 0 * page_stride_head
+                + off_dim * page_stride_dim,
+                mask=mask,
+            )
+            tl.store(
+                kv_ptr
+                + layer_index * kv_stride_layer_num
+                + mem_index * kv_stride_size
+                + 0 * kv_stride_head
+                + off_dim * kv_stride_dim,
+                page_tensor,
+                mask=mask,
+            )
     return
 
-def mla_page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torch.Tensor, mode:str):
+
+def mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torch.Tensor, mode: str):
     assert mode in ["read", "write"]
     assert mem_indexes.is_contiguous()
     assert page_tensor.is_contiguous()
@@ -189,7 +289,6 @@ def mla_page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
     assert page_head_dim == head_dim
     assert page_head_num == kv_head_num == 1
 
-
    token_num = len(mem_indexes)
     grid = (token_num,)
 
@@ -208,7 +307,7 @@ def mla_page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
         layer_num=layer_num,
         head_dim=head_dim,
         HEAD_DIM_BLOCK=triton.next_power_of_2(head_dim),
-        IS_WRITE=mode=="write",
+        IS_WRITE=mode == "write",
         NEED_MASK=triton.next_power_of_2(head_dim) != head_dim,
         num_warps=1,
     )
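
Both kernels flatten a 4-D address as base + layer * stride_layer + token * stride_size + head * stride_head + dim * stride_dim, with the strides cast to int64 so offsets cannot overflow 32-bit arithmetic on large caches. As a sketch only, here is a torch-level reference of what `page_io` moves; it assumes the layouts the stride arguments suggest, kv_buffer as (layer_num, size, head_num, head_dim) and page_tensor as (page_size, layer_num, head_num, head_dim), and it ignores the per-rank head-range selection (`page_head_start`) that the real kernel applies:

import torch

def page_io_reference(mem_indexes, page_tensor, kv_buffer, mode):
    # kv_buffer:   (layer_num, size, head_num, head_dim)
    # page_tensor: (page_size, layer_num, head_num, head_dim)
    n = len(mem_indexes)
    if mode == "write":
        # gather the selected tokens and reorder them token-major
        page_tensor[:n] = kv_buffer[:, mem_indexes].transpose(0, 1)
    else:
        # scatter the staged tokens back into the cache
        kv_buffer[:, mem_indexes] = page_tensor[:n].transpose(0, 1)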

lightllm/common/mem_manager.py

Lines changed: 35 additions & 28 deletions
@@ -105,55 +105,62 @@ def alloc_kv_move_buffer(self, max_req_total_len):
     def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
         if isinstance(self, MemoryManager) and type(self) != MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")
-
+
         num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
         self.kv_move_buffer = torch.empty(
             (page_num, page_size, self.layer_num, 2 * num_kv_head, self.head_dim), dtype=self.dtype, device="cuda"
         )
         return self.kv_move_buffer
-
-    def write_mem_to_page_kv_move_buffer(self,
-                                         mem_indexes: List[int],
-                                         page_index: int,
-                                         dp_index: int,
-                                         mem_managers: List["MemoryManager"],
-                                         dp_world_size:int):
+
+    def write_mem_to_page_kv_move_buffer(
+        self,
+        mem_indexes: List[int],
+        page_index: int,
+        dp_index: int,
+        mem_managers: List["MemoryManager"],
+        dp_world_size: int,
+    ):
         cur_page = self.kv_move_buffer[page_index]
         repeat_count = dp_world_size * self.kv_buffer.shape[2] // self.kv_move_buffer.shape[3]
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         for tp_index in range(dp_world_size):
             if tp_index % repeat_count == 0:
-                page_io(torch.tensor(mem_indexes, dtype=torch.int64, device="cuda"),
-                        page_tensor=cur_page,
-                        kv_buffer=dp_mems[tp_index].kv_buffer,
-                        tp_index=tp_index,
-                        tp_world_size=dp_world_size,
-                        mode="write")
+                page_io(
+                    torch.tensor(mem_indexes, dtype=torch.int64, device="cuda"),
+                    page_tensor=cur_page,
+                    kv_buffer=dp_mems[tp_index].kv_buffer,
+                    tp_index=tp_index,
+                    tp_world_size=dp_world_size,
+                    mode="write",
+                )
         # keep for debug
         # logger.info(f"src token tensor {self.kv_buffer[:, mem_indexes[0], 0, 0]}")
         # logger.info(f"src page token tensor {cur_page[0, :, 0, 0]}")
         return
-
-    def read_page_kv_move_buffer_to_mem(self,
-                                        mem_indexes: List[int],
-                                        page_index: int,
-                                        dp_index: int,
-                                        mem_managers: List["MemoryManager"],
-                                        dp_world_size:int):
+
+    def read_page_kv_move_buffer_to_mem(
+        self,
+        mem_indexes: List[int],
+        page_index: int,
+        dp_index: int,
+        mem_managers: List["MemoryManager"],
+        dp_world_size: int,
+    ):
         cur_page = self.kv_move_buffer[page_index]
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         for tp_index in range(dp_world_size):
-            page_io(torch.tensor(mem_indexes, dtype=torch.int64, device="cuda"),
-                    page_tensor=cur_page,
-                    kv_buffer=dp_mems[tp_index].kv_buffer,
-                    tp_index=tp_index,
-                    tp_world_size=dp_world_size,
-                    mode="read")
+            page_io(
+                torch.tensor(mem_indexes, dtype=torch.int64, device="cuda"),
+                page_tensor=cur_page,
+                kv_buffer=dp_mems[tp_index].kv_buffer,
+                tp_index=tp_index,
+                tp_world_size=dp_world_size,
+                mode="read",
+            )
         # keep for debug
         # logger.info(f"dst token tensor {self.kv_buffer[:, mem_indexes[0], 0, 0]}")
        # logger.info(f"dst page token tensor {cur_page[0, :, 0, 0]}")
 
-
     def send_to_decode_node(
         self,
         move_tasks: List[KVMoveTask],
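
The `repeat_count` guard skips TP ranks whose KV heads duplicate heads another rank has already staged. Worked arithmetic with assumed numbers (not taken from the diff):

# Assumed for illustration: dp_world_size = 8 TP ranks, each rank's
# kv_buffer holding 2 KV heads (kv_buffer.shape[2] == 2), and a move
# buffer covering all 8 distinct heads (kv_move_buffer.shape[3] == 8).
repeat_count = 8 * 2 // 8  # == 2: each head is held by 2 ranks
for tp_index in range(8):
    if tp_index % repeat_count == 0:
        ...  # ranks 0, 2, 4, 6 run page_io; 1, 3, 5, 7 hold duplicates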

lightllm/server/core/objs/req.py

Lines changed: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ def get_all_prompt_metadata(self):
         metadata["prompt_token_ids"] = [int(e) for e in cur_ids]
         self._cache_prompt_metadata = metadata
         return metadata
-
+
     def is_infer_decode(self) -> bool:
         """
         judge the req is in decode stage

lightllm/server/core/objs/sampling_params.py

Lines changed: 4 additions & 3 deletions
@@ -222,7 +222,8 @@ def initialize(self, inputs: Tuple[int, float]):
 
     def to_tuple(self):
         return (self.item0, self.item1)
-
+
+
 class NodeUUId(ctypes.Structure):
     _pack_ = 4
     _fields_ = [
@@ -236,7 +237,7 @@ def initialize(self, node_id: int):
         return
 
     def get(self) -> int:
-        return ((self.node_id_high << 64) | self.node_id_low)
+        return (self.node_id_high << 64) | self.node_id_low
 
 
 class DecodeNode(ctypes.Structure):
@@ -308,7 +309,7 @@ class SamplingParams(ctypes.Structure):
         ("group_request_id", ctypes.c_int64),  # p d mode used params
         ("suggested_dp_index", ctypes.c_int),  # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index
         ("move_kv_to_decode_node", DecodeNode),  # move kv to deocde node, only used in pd mode
-    # in pd split mode, use to keep the id of pd master
+        # in pd split mode, use to keep the id of pd master
         ("pd_master_node_id", NodeUUId),
         # nixl params object, only used in nixl pd mode, used to build nixl connection in p and d
         ("nixl_params", NIXLParamObj),
