
Commit 46ec6d7

WoosukKwon and njhill authored
[Model Runner V2] Use a different stream for grammar bitmask h2d copy (vllm-project#33059)
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent e82fa44 commit 46ec6d7
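
The core idea of this change is to issue the host-to-device (h2d) copies for the grammar bitmask and the logits-index mapping on a dedicated CUDA stream, so they can overlap with work already queued on the main stream, and to fence the two streams against each other around the masking kernel. The following is a minimal sketch of that pattern in plain PyTorch, written for illustration only; the buffer size, names, and the stand-in kernel are assumptions, not vLLM code.

import torch

device = torch.device("cuda")
copy_stream = torch.cuda.Stream()

# Preallocated device-side destination buffer (illustrative size).
gpu_buf = torch.zeros(1024, dtype=torch.int32, device=device)


def run_step(host_values: list[int]) -> torch.Tensor:
    main_stream = torch.cuda.current_stream()

    # Stage the data in pinned host memory and enqueue the copy on the side
    # stream; pin_memory=True is what lets the h2d copy run asynchronously.
    with torch.cuda.stream(copy_stream):
        staging = torch.tensor(
            host_values, dtype=torch.int32, device="cpu", pin_memory=True
        )
        dst = gpu_buf[: len(host_values)]
        dst.copy_(staging, non_blocking=True)

    # The main stream must not read dst before the copy has finished.
    main_stream.wait_stream(copy_stream)
    out = dst * 2  # stand-in for the kernel that consumes the copied data

    # Conversely, the copy stream must not reuse dst on a later step before
    # the main stream is done reading it.
    copy_stream.wait_stream(main_stream)
    return out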

File tree

2 files changed (+32, -11 lines)


vllm/v1/worker/gpu/model_runner.py

Lines changed: 1 addition & 0 deletions

@@ -168,6 +168,7 @@ def __init__(
         self.structured_outputs_worker = StructuredOutputsWorker(
             max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
             vocab_size=self.vocab_size,
+            device=self.device,
         )
         # LoRA-related workers.
         self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
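
In the model runner, the only change is to pass the runner's device into StructuredOutputsWorker. Presumably this is what lets the worker preallocate its device-side buffers and create its dedicated copy stream on the same GPU the model runs on; see the structured_outputs.py diff that follows.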

vllm/v1/worker/gpu/structured_outputs.py

Lines changed: 31 additions & 11 deletions

@@ -5,7 +5,7 @@

 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
-from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
+from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
 from vllm.v1.worker.gpu.input_batch import InputBatch

@@ -14,13 +14,16 @@ def __init__(
         self,
         max_num_logits: int,
         vocab_size: int,
+        device: torch.device,
     ):
-        # NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
-        # to save a unnecessary CPU-to-CPU copy.
-        self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
-        self.grammar_bitmask = UvaBufferPool(
-            (max_num_logits, cdiv(vocab_size, 32)), torch.int32
+        self.logits_indices = torch.zeros(
+            max_num_logits, dtype=torch.int32, device=device
         )
+        self.grammar_bitmask = torch.zeros(
+            (max_num_logits, cdiv(vocab_size, 32)), dtype=torch.int32, device=device
+        )
+        self.device = device
+        self.copy_stream = torch.cuda.Stream()

     def apply_grammar_bitmask(
         self,
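
Compared with the previous UvaBufferPool-backed buffers, the worker now preallocates plain device tensors at their maximum size once in __init__ and slices them per step, and it creates a single dedicated torch.cuda.Stream up front that every subsequent host-to-device copy is issued on.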
@@ -32,6 +35,12 @@ def apply_grammar_bitmask(
         if not grammar_req_ids:
             return

+        # Asynchronously copy the bitmask to GPU.
+        with torch.cuda.stream(self.copy_stream):
+            bitmask = async_copy_to_gpu(
+                grammar_bitmask, out=self.grammar_bitmask[: grammar_bitmask.shape[0]]
+            )
+
         # Construct bitmask -> logits mapping
         mapping: list[int] = []
         req_ids = input_batch.req_ids
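
async_copy_to_gpu is defined in vllm/v1/worker/gpu/buffer_utils.py and is not shown in this diff. As a rough sketch of what such a helper plausibly does under the hood (an assumption; the real implementation may differ), it stages the host data in pinned memory and enqueues a non-blocking copy into the preallocated output slice on whatever stream is current:

import torch


def async_copy_to_gpu_sketch(src: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for vLLM's async_copy_to_gpu, for illustration only.
    # Pin the host tensor (if it is not already pinned) so the copy can be async.
    if not src.is_pinned():
        src = src.pin_memory()
    # Enqueue the h2d copy on the caller's current stream (here, the copy stream).
    out.copy_(src, non_blocking=True)
    return out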
@@ -42,12 +51,19 @@
             logits_start_idx = cu_num_logits[req_idx]
             logits_end_idx = cu_num_logits[req_idx + 1]
             mapping.extend(range(logits_start_idx, logits_end_idx))
-        # Copy the mapping.
-        mapping_np = np.array(mapping, dtype=np.int32)
-        logits_indices = self.logits_indices.copy_to_uva(mapping_np)

-        # Copy the bitmask.
-        bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)
+        # Asynchronously copy the mapping to GPU.
+        with torch.cuda.stream(self.copy_stream):
+            logits_indices = torch.tensor(
+                mapping, dtype=torch.int32, device="cpu", pin_memory=True
+            )
+            logits_indices = self.logits_indices[: len(mapping)].copy_(
+                logits_indices, non_blocking=True
+            )
+
+        # Ensure all async copies are complete before launching the kernel.
+        current_stream = torch.cuda.current_stream()
+        current_stream.wait_stream(self.copy_stream)

         num_masks = bitmask.shape[0]
         assert num_masks == len(mapping)
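
Note the ordering here: the pinned staging tensor and the non_blocking copy are both issued under self.copy_stream, and the main stream only proceeds once wait_stream guarantees that the bitmask and the mapping have both landed on the GPU, so the masking kernel launched below never races ahead of the copies. The pin_memory=True flag is what allows the copy to be genuinely asynchronous with respect to the host.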
@@ -64,6 +80,10 @@
             BLOCK_SIZE=BLOCK_SIZE,
         )

+        # Ensure the copy stream waits for the device tensors to finish being used
+        # before it re-uses or deallocates them
+        self.copy_stream.wait_stream(current_stream)
+

 # Adapted from
 # https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
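
The final wait_stream in the opposite direction makes the copy stream wait for the main stream before the next step's copies can begin. As the in-diff comment says, this keeps the preallocated device buffers (and the per-step staging tensors) from being reused or freed while the masking kernel may still be reading them.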
