
Commit ca72074

fix
1 parent f11a594 commit ca72074

2 files changed: +14 -18 lines changed


lightllm/common/deepseek2_mem_manager.py

Lines changed: 7 additions & 9 deletions
@@ -41,7 +41,9 @@ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
         self.kv_move_buffer = torch.empty(
             (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
-        self._buffer_mem_indexes_tensors = [torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num) ]
+        self._buffer_mem_indexes_tensors = [
+            torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num)
+        ]
         return self.kv_move_buffer

     def write_mem_to_page_kv_move_buffer(
@@ -53,11 +55,9 @@ def write_mem_to_page_kv_move_buffer(
         dp_world_size: int,
     ):
         cur_page = self.kv_move_buffer[page_index]
-        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0:len(mem_indexes)]
+        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)]
         pin_mem_indexes.numpy()[:] = mem_indexes
-        mem_indexes_gpu = pin_mem_indexes.cuda(
-            non_blocking=True
-        )
+        mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         mla_page_io(
             mem_indexes=mem_indexes_gpu,
@@ -76,11 +76,9 @@ def read_page_kv_move_buffer_to_mem(
         dp_world_size: int,
     ):
         cur_page = self.kv_move_buffer[page_index]
-        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0:len(mem_indexes)]
+        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)]
         pin_mem_indexes.numpy()[:] = mem_indexes
-        mem_indexes_gpu = pin_mem_indexes.cuda(
-            non_blocking=True
-        )
+        mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         for mem in dp_mems:
             mla_page_io(
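
The lines touched here follow one staging pattern: a pinned CPU tensor is preallocated per kv-move page, filled in place through its NumPy view, and shipped to the GPU with a non-blocking copy. A minimal standalone sketch of that pattern (the names and sizes below are illustrative, not LightLLM's actual API):

import torch

page_num, page_size = 4, 256

# Hypothetical stand-in for self._buffer_mem_indexes_tensors: one pinned
# int64 buffer per kv-move page, allocated once and reused on every transfer.
buffer_mem_indexes_tensors = [
    torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True)
    for _ in range(page_num)
]

def stage_mem_indexes(page_index, mem_indexes):
    # Slice the reusable pinned buffer to the number of indexes being moved.
    pin_mem_indexes = buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)]
    # Write into pinned memory through the NumPy view; no new allocation per call.
    pin_mem_indexes.numpy()[:] = mem_indexes
    # A pinned source lets the host-to-device copy run asynchronously.
    return pin_mem_indexes.cuda(non_blocking=True)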

lightllm/common/mem_manager.py

Lines changed: 7 additions & 9 deletions
@@ -110,7 +110,9 @@ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
         self.kv_move_buffer = torch.empty(
             (page_num, page_size, self.layer_num, 2 * num_kv_head, self.head_dim), dtype=self.dtype, device="cuda"
         )
-        self._buffer_mem_indexes_tensors = [torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num) ]
+        self._buffer_mem_indexes_tensors = [
+            torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num)
+        ]
         return self.kv_move_buffer

     def write_mem_to_page_kv_move_buffer(
@@ -122,11 +124,9 @@ def write_mem_to_page_kv_move_buffer(
         dp_world_size: int,
     ):
         cur_page = self.kv_move_buffer[page_index]
-        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0:len(mem_indexes)]
+        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)]
         pin_mem_indexes.numpy()[:] = mem_indexes
-        mem_indexes_gpu = pin_mem_indexes.cuda(
-            non_blocking=True
-        )
+        mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
         repeat_count = dp_world_size * self.kv_buffer.shape[2] // self.kv_move_buffer.shape[3]
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         for tp_index in range(dp_world_size):
@@ -153,11 +153,9 @@ def read_page_kv_move_buffer_to_mem(
         dp_world_size: int,
     ):
         cur_page = self.kv_move_buffer[page_index]
-        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0:len(mem_indexes)]
+        pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)]
         pin_mem_indexes.numpy()[:] = mem_indexes
-        mem_indexes_gpu = pin_mem_indexes.cuda(
-            non_blocking=True
-        )
+        mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         mem_indexes_gpu = torch.tensor(mem_indexes, dtype=torch.int64, device="cpu", pin_memory=True).cuda(
             non_blocking=True
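
Both the write and read paths pick out the memory managers belonging to one data-parallel group by slicing the flat mem_managers list with dp_index and dp_world_size. A small illustration of that indexing (the values are made up; dp_world_size is read here as the number of TP ranks per DP group, matching the for tp_index in range(dp_world_size) loop above):

# Illustration only: 8 managers laid out as 2 DP groups of 4 TP ranks each.
mem_managers = [f"mem_manager_{i}" for i in range(8)]
dp_world_size = 4  # ranks per data-parallel group
dp_index = 1       # select the second group

dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
print(dp_mems)  # ['mem_manager_4', 'mem_manager_5', 'mem_manager_6', 'mem_manager_7']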
