Skip to content

Commit 75c0935

Browse files
author
wangzaijun
committed
fix
1 parent 2d650fd commit 75c0935

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

lightllm/common/mem_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,11 @@ def copy_kv_from_other_dp_ranks(
420420
mems_ptr_list = []
421421
for i in range(0, len(mem_managers)):
422422
mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr())
423-
self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cuda")
423+
self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True)
424424

425425
# 一次性传输所有层
426426
kv_trans_for_dp(
427-
input_mems=self.mem_ptrs_tensor,
427+
input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True),
428428
input_idx=move_token_indexes,
429429
input_dp_idx=token_dp_indexes,
430430
output=self.kv_buffer,

0 commit comments

Comments
 (0)