
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
-from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
+from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import InputBatch


@@ -14,13 +14,16 @@ def __init__(
        self,
        max_num_logits: int,
        vocab_size: int,
+        device: torch.device,
    ):
-        # NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
-        # to save a unnecessary CPU-to-CPU copy.
-        self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
-        self.grammar_bitmask = UvaBufferPool(
-            (max_num_logits, cdiv(vocab_size, 32)), torch.int32
+        self.logits_indices = torch.zeros(
+            max_num_logits, dtype=torch.int32, device=device
        )
+        self.grammar_bitmask = torch.zeros(
+            (max_num_logits, cdiv(vocab_size, 32)), dtype=torch.int32, device=device
+        )
+        self.device = device
+        self.copy_stream = torch.cuda.Stream()

    def apply_grammar_bitmask(
        self,
@@ -32,6 +35,12 @@ def apply_grammar_bitmask(
        if not grammar_req_ids:
            return

+        # Asynchronously copy the bitmask to GPU.
+        with torch.cuda.stream(self.copy_stream):
+            bitmask = async_copy_to_gpu(
+                grammar_bitmask, out=self.grammar_bitmask[: grammar_bitmask.shape[0]]
+            )
+
        # Construct bitmask -> logits mapping
        mapping: list[int] = []
        req_ids = input_batch.req_ids
@@ -42,12 +51,19 @@ def apply_grammar_bitmask(
            logits_start_idx = cu_num_logits[req_idx]
            logits_end_idx = cu_num_logits[req_idx + 1]
            mapping.extend(range(logits_start_idx, logits_end_idx))
-        # Copy the mapping.
-        mapping_np = np.array(mapping, dtype=np.int32)
-        logits_indices = self.logits_indices.copy_to_uva(mapping_np)

-        # Copy the bitmask.
-        bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)
+        # Asynchronously copy the mapping to GPU.
+        with torch.cuda.stream(self.copy_stream):
+            logits_indices = torch.tensor(
+                mapping, dtype=torch.int32, device="cpu", pin_memory=True
+            )
+            logits_indices = self.logits_indices[: len(mapping)].copy_(
+                logits_indices, non_blocking=True
+            )
+
+        # Ensure all async copies are complete before launching the kernel.
+        current_stream = torch.cuda.current_stream()
+        current_stream.wait_stream(self.copy_stream)

        num_masks = bitmask.shape[0]
        assert num_masks == len(mapping)
@@ -64,6 +80,10 @@ def apply_grammar_bitmask(
            BLOCK_SIZE=BLOCK_SIZE,
        )

+        # Ensure the copy stream waits for the device tensors to finish being used
+        # before it re-uses or deallocates them
+        self.copy_stream.wait_stream(current_stream)
+

# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
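For reference, a minimal standalone sketch of the side-stream copy pattern this change relies on: stage host data in pinned memory, issue the host-to-device copy on a dedicated CUDA stream, make the compute stream wait on the copy stream before the kernel reads the buffers, and make the copy stream wait on the compute stream before the buffers are touched again. The tensor names and the toy "kernel" below are illustrative assumptions, not vLLM APIs; the real code uses async_copy_to_gpu and the Triton apply-token-bitmask kernel.

import torch

device = torch.device("cuda")
copy_stream = torch.cuda.Stream()
main_stream = torch.cuda.current_stream()

# Preallocated GPU destination buffer (analogous to self.grammar_bitmask above).
gpu_buf = torch.zeros(1024, dtype=torch.int32, device=device)

# Source data staged in pinned host memory so the H2D copy can run asynchronously.
cpu_buf = torch.arange(1024, dtype=torch.int32).pin_memory()

# 1. Issue the copy on the side stream so it can overlap with other GPU work.
with torch.cuda.stream(copy_stream):
    gpu_buf.copy_(cpu_buf, non_blocking=True)

# 2. The compute stream must not read gpu_buf until the copy has finished.
main_stream.wait_stream(copy_stream)

out = gpu_buf * 2  # stand-in for the Triton bitmask kernel launch

# 3. The copy stream must not reuse the buffers until the kernel is done with them.
copy_stream.wait_stream(main_stream)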
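On the buffer shape: the grammar bitmask packs one bit per vocabulary token into int32 words, which is why the preallocated tensor has cdiv(vocab_size, 32) columns. A small illustrative sketch of that packing and of masking logits with it (the variable names are assumptions; the actual masking is done by the Triton kernel referenced above, with set bits marking allowed tokens):

import torch

vocab_size = 100
num_words = (vocab_size + 31) // 32  # cdiv(vocab_size, 32): one bit per token

# Allow only tokens 3 and 64 by setting their bits in the packed mask.
bitmask = torch.zeros(num_words, dtype=torch.int32)
for tok in (3, 64):
    bitmask[tok // 32] |= 1 << (tok % 32)

# Unpack to one flag per token and set disallowed logits to -inf.
logits = torch.randn(vocab_size)
idx = torch.arange(vocab_size)
allowed = (bitmask[idx // 32] >> (idx % 32)) & 1
logits = logits.masked_fill(allowed == 0, float("-inf"))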