Commit 7c1a597: improve pin mem manager
Parent: 4157ff4

File tree: 3 files changed, +39 / -20 lines

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 14 additions & 10 deletions
@@ -103,19 +103,23 @@ def prefill_normal(
         next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
 
         scatter_token(
-            next_token_ids,
-            self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
-            model_input.b_req_idx,
-            model_input.b_mtp_index,
+            next_token_ids=next_token_ids,
+            req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+            b_req_idx=model_input.b_req_idx,
+            b_mtp_index=model_input.b_mtp_index,
+            b_has_out=g_pin_mem_manager.gen_from_list(
+                key="b_has_out", data=model_input.b_prefill_has_output_cpu, dtype=torch.bool
+            ).cuda(non_blocking=True),
         )
-        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
+        next_token_ids_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
+            key="next_token_ids",
+            gpu_tensor=next_token_ids,
         )
-        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
+        next_token_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
+            key="next_token_logprobs",
+            gpu_tensor=next_token_logprobs,
         )
-        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
-        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
+
         sync_event = torch.cuda.Event()
         sync_event.record()
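The hunk does two things: scatter_token now receives a b_has_out mask staged in pinned memory via gen_from_list, and the manual alloc_pin_tensor plus copy_ pair for the sampled outputs is folded into async_copy_from_gpu_tensor. A minimal sketch of the pinned-upload half of that pattern, with illustrative names and data rather than lightllm code:

import torch

# Per-request output flags that have to reach the GPU (illustrative data, not lightllm's).
flags = [True, False, True]

# Stage the list in page-locked (pinned) host memory first ...
pinned = torch.empty(len(flags), dtype=torch.bool, pin_memory=True)
pinned.numpy()[:] = flags

# ... so the host-to-device copy can be issued asynchronously on the current stream.
b_has_out = pinned.cuda(non_blocking=True)

Pinning matters because an async host-to-device copy can only overlap with compute when the host buffer is page-locked; a non-blocking copy from ordinary pageable memory falls back to a synchronizing path.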

lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py

Lines changed: 12 additions & 9 deletions
@@ -63,13 +63,16 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq
         b_req_idx = [req.req_idx for req in run_reqs]
         b_has_out = [model_input.b_prefill_has_output_cpu[i] for i in batch_idx]
 
-        batch_idx = torch.tensor(batch_idx, dtype=torch.int64, device="cpu", pin_memory=True).cuda(
+        batch_idx = g_pin_mem_manager.gen_from_list(key="batch_idx_", data=batch_idx, dtype=torch.int64).cuda(
             non_blocking=True
         )
-        b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu", pin_memory=True).cuda(
+        b_req_idx = g_pin_mem_manager.gen_from_list(key="b_req_idx_", data=b_req_idx, dtype=torch.int32).cuda(
             non_blocking=True
         )
-        b_has_out = torch.tensor(b_has_out, dtype=torch.bool, device="cpu", pin_memory=True).cuda(non_blocking=True)
+        b_has_out = g_pin_mem_manager.gen_from_list(key="b_has_out_", data=b_has_out, dtype=torch.bool).cuda(
+            non_blocking=True
+        )
+
         logits = logits[batch_idx]
         b_mtp_index = model_input.b_mtp_index[batch_idx]
 
@@ -83,14 +86,14 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq
             b_has_out=b_has_out,
         )
 
-        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
+        next_token_ids_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
+            key="next_token_ids",
+            gpu_tensor=next_token_ids,
         )
-        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
+        next_token_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
+            key="next_token_logprobs",
+            gpu_tensor=next_token_logprobs,
         )
-        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
-        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
         sync_event = torch.cuda.Event()
         sync_event.record()
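The diverse backend gets the same treatment: the ad-hoc torch.tensor(..., pin_memory=True) staging and the trailing copy_(..., non_blocking=True) calls are routed through the manager. The contract on the download side is unchanged: the pinned CPU buffer becomes valid only after the recorded CUDA event completes. A small sketch of that pattern with assumed variable names, not lightllm code:

import torch

# Freshly sampled ids living on the GPU (illustrative).
next_token_ids = torch.randint(0, 32000, (8,), device="cuda")

# Device-to-host copy into a pinned buffer; copy_ returns before the data has landed.
cpu_buf = torch.empty(next_token_ids.shape, dtype=next_token_ids.dtype, pin_memory=True)
cpu_buf.copy_(next_token_ids, non_blocking=True)

# Record an event right after the copy is enqueued ...
sync_event = torch.cuda.Event()
sync_event.record()

# ... and wait on it before the host reads the buffer.
sync_event.synchronize()
token_ids = cpu_buf.tolist()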

lightllm/server/router/model_infer/pin_mem_manager.py

Lines changed: 13 additions & 1 deletion
@@ -11,7 +11,7 @@ def __init__(self):
         self.key_to_alloc_index: Dict[str, int] = {}
         self.buffer_size = 4
 
-    def alloc_pin_tensor(self, key: str, size: int, dtype: torch.dtype):
+    def alloc_pin_tensor(self, key: str, size: int, dtype: torch.dtype) -> torch.Tensor:
         """
         Use a cache of buffer_size pinned-memory buffers to speed up allocating and releasing pinned memory.
         """
@@ -34,5 +34,17 @@ def alloc_pin_tensor(self, key: str, size: int, dtype: torch.dtype):
         self.key_to_alloc_index[key] = (alloc_index + 1) % self.buffer_size
         return buff_tensor[0:size]
 
+    def gen_from_list(self, key: str, data: List, dtype: torch.dtype) -> torch.Tensor:
+        size = len(data)
+        pin_mem = self.alloc_pin_tensor(key, size=size, dtype=dtype)
+        pin_mem.numpy()[:] = data
+        return pin_mem
+
+    def async_copy_from_gpu_tensor(self, key: str, gpu_tensor: torch.Tensor) -> torch.Tensor:
+        size = gpu_tensor.numel()
+        pin_mem = self.alloc_pin_tensor(key, size=size, dtype=gpu_tensor.dtype)
+        pin_mem.copy_(gpu_tensor.view(-1), non_blocking=True)
+        return pin_mem.view(gpu_tensor.shape)
+
 
 g_pin_mem_manager = PinMemTensorManager()
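Together the two new helpers sit on top of the manager's per-key ring of reusable pinned buffers: gen_from_list stages a Python list in pinned memory for an asynchronous upload, and async_copy_from_gpu_tensor stages a GPU tensor in pinned memory for the host. A usage sketch; the event handling mirrors the backend code above and is an assumption, not part of this file:

import torch
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

# Host list -> pinned CPU tensor -> asynchronous upload to the GPU.
b_has_out = g_pin_mem_manager.gen_from_list(
    key="b_has_out", data=[True, False, True], dtype=torch.bool
).cuda(non_blocking=True)

# GPU tensor -> pinned CPU tensor (asynchronous download).
next_token_ids = torch.randint(0, 32000, (3,), device="cuda")
next_token_ids_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
    key="next_token_ids", gpu_tensor=next_token_ids
)

# The pinned result is only safe to read once the enqueued copy has finished.
sync_event = torch.cuda.Event()
sync_event.record()
sync_event.synchronize()
print(next_token_ids_cpu.tolist())

Because alloc_pin_tensor cycles through buffer_size (here 4) buffers per key in round-robin order, a returned pinned tensor has to be consumed before its key wraps back around to the same slot.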
