
Commit 1df7044

fix mtp mem alloc in overlap manner (#1044)
1 parent 6475617 commit 1df7044

1 file changed (+11, −1 lines)

lightllm/common/mem_manager.py

Lines changed: 11 additions & 1 deletion
@@ -30,6 +30,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.mem_state = torch.arange(
             0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
         )
+        self._mem_state_return = torch.arange(
+            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._return_start = 0
         self.mark_start = 0
         self.mark_end = self.size

@@ -250,11 +254,17 @@ def alloc(self, need_size) -> torch.Tensor:
         start = self.mark_start
         end = self.mark_start + need_size
-        ans = self.mem_state[start:end]
         self.mark_start += need_size

         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        # Return through a separate buffer to avoid memory races in the asynchronous (overlap) case
+        if self._return_start + need_size > self._mem_state_return.shape[0]:
+            self._return_start = 0
+        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
+        ans.copy_(self.mem_state[start:end])
+        self._return_start += need_size
         return ans

     def free(self, free_index: Union[torch.Tensor, List[int]]):
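The change above makes alloc() return a copy taken from an oversized rotating scratch buffer (_mem_state_return, sized 3x the pool) instead of a view into self.mem_state. In overlap/MTP mode the next allocation can be prepared asynchronously while earlier returned index tensors are still being consumed, so handing out direct views risked later bookkeeping on mem_state clobbering them. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the ToyAllocator class and its simplified bookkeeping are illustrative only and do not reflect lightllm's MemoryManager API.

import torch

# Illustrative sketch of the "return through a rotating buffer" pattern.
class ToyAllocator:
    def __init__(self, size: int):
        self.size = size
        # Pool of free token indices; slices of this tensor used to be returned directly.
        self.mem_state = torch.arange(0, size, dtype=torch.int32)
        # Oversized scratch buffer that only ever hands out copies, so callers never
        # hold a view that later updates to mem_state could overwrite.
        self._mem_state_return = torch.empty(size * 3, dtype=torch.int32)
        self._return_start = 0
        self.mark_start = 0

    def alloc(self, need_size: int) -> torch.Tensor:
        start = self.mark_start
        end = start + need_size
        self.mark_start += need_size

        # Wrap the rotating return buffer when the next slice would run past its end.
        if self._return_start + need_size > self._mem_state_return.shape[0]:
            self._return_start = 0
        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
        ans.copy_(self.mem_state[start:end])
        self._return_start += need_size
        return ans

allocator = ToyAllocator(size=16)
a = allocator.alloc(4)
b = allocator.alloc(4)
# a and b live in disjoint regions of the scratch buffer, so preparing b
# (possibly on another thread or stream in overlap mode) cannot corrupt a.
print(a.tolist(), b.tolist())

One reading of the 3x sizing is that the scratch buffer can hold several outstanding allocations before wrapping, which keeps earlier results valid as long as their consumers finish before the write cursor comes back around; the exact factor is the commit author's choice, not something derived here.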
