 import triton.language as tl
 
 
-@triton.jit
-def _fwd_kernel_gather_and_scatter(
-    probs_idx,
-    probs_sort,
-    req_to_next_token_ids,
-    req_to_next_token_probs,
-    sampled_index,
-    b_req_idx,
-    probs_idx_stride,
-    probs_sort_stride,
-    req_to_next_token_ids_stride,
-    req_to_next_token_probs_stride,
-):
-    cur_index = tl.program_id(0)
-    cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_sampled_index = tl.load(sampled_index + cur_index)
-    cur_token_index = tl.load(probs_idx + cur_index * probs_idx_stride + cur_sampled_index)
-    cur_token_probs = tl.load(probs_sort + cur_index * probs_sort_stride + cur_sampled_index)
-    tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride, cur_token_index)
-    tl.store(req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride, tl.log(cur_token_probs))
-    return
-
-
-@torch.no_grad()
-def gather_and_scatter_token_to_cpu(
-    probs_idx: torch.Tensor,
-    probs_sort: torch.Tensor,
-    req_to_next_token_ids: torch.Tensor,
-    req_to_next_token_probs: torch.Tensor,
-    sampled_index: torch.Tensor,
-    b_req_idx: torch.Tensor,
-):
-    """
-    Gather the sampled next_token_id and next_token_probs (GPU tensors) and scatter them
-    into req_to_next_token_ids and req_to_next_token_probs (CPU tensors).
-    Args:
-        probs_idx: (batch_size, vocab_size)
-        probs_sort: (batch_size, vocab_size)
-        req_to_next_token_ids: (max_req_num,)
-        req_to_next_token_probs: (max_req_num,)
-        sampled_index: (batch_size,)
-        b_req_idx: (batch_size,)
-    """
-    assert probs_idx.shape == probs_sort.shape
-    assert sampled_index.shape[0] == b_req_idx.shape[0]
-    batch_size = b_req_idx.shape[0]
-    grid = (batch_size,)
-    num_warps = 1
-
-    _fwd_kernel_gather_and_scatter[grid](
-        probs_idx,
-        probs_sort,
-        req_to_next_token_ids,
-        req_to_next_token_probs,
-        sampled_index,
-        b_req_idx,
-        probs_idx.stride(0),
-        probs_sort.stride(0),
-        req_to_next_token_ids.stride(0),
-        req_to_next_token_probs.stride(0),
-        num_warps=num_warps,
-        num_stages=1,
-    )
-    return
-
-
 @triton.jit
 def _fwd_kernel_scatter(
     next_token_ids,
     req_to_next_token_ids,
     b_req_idx,
     b_mtp_index,
+    b_has_out,
     req_to_next_token_ids_stride,
     req_to_next_token_ids_stride_1,
+    num_size,
+    HAS_OUT_IS_NONE: tl.constexpr,
+    BLOCK: tl.constexpr,
 ):
-    cur_index = tl.program_id(0)
-    cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_mtp_index = tl.load(b_mtp_index + cur_index)
-    cur_next_token_id = tl.load(next_token_ids + cur_index)
-    tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, cur_next_token_id)
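+    # Each program instance now handles a BLOCK-sized slice of the batch;
+    # block_mask guards the tail when batch_size is not a multiple of BLOCK.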
+    block_index = tl.program_id(0)
+    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
+    block_mask = block_range < num_size
+
+    cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
+    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
+    cur_next_token_id = tl.load(next_token_ids + block_range, mask=block_mask)
+
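+    # When b_has_out is provided, only lanes whose has_out flag is set are stored;
+    # with HAS_OUT_IS_NONE, every in-range lane is written unconditionally.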
+    if not HAS_OUT_IS_NONE:
+        cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False)
+        tl.store(
+            req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
+            cur_next_token_id,
+            mask=cur_has_out & block_mask,
+        )
+    else:
+        tl.store(
+            req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
+            cur_next_token_id,
+            mask=block_mask,
+        )
+
     return
 
 
@@ -93,6 +48,7 @@ def scatter_token(
     req_to_next_token_ids: torch.Tensor,
     b_req_idx: torch.Tensor,
     b_mtp_index: torch.Tensor,
+    b_has_out: torch.Tensor = None,
 ):
     """
     This function is used to scatter the token_info (GPU tensor) to the req_to_token_info (CPU tensor).
@@ -104,16 +60,22 @@ def scatter_token(
     """
     assert next_token_ids.shape[0] == b_req_idx.shape[0]
     batch_size = b_req_idx.shape[0]
-    grid = (batch_size,)
+    BLOCK = 256
+
+    grid = (triton.cdiv(batch_size, BLOCK),)
     num_warps = 1
 
     _fwd_kernel_scatter[grid](
-        next_token_ids,
-        req_to_next_token_ids,
-        b_req_idx,
-        b_mtp_index,
-        req_to_next_token_ids.stride(0),
-        req_to_next_token_ids.stride(1),
+        next_token_ids=next_token_ids,
+        req_to_next_token_ids=req_to_next_token_ids,
+        b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
+        b_has_out=b_has_out,
+        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
+        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
+        num_size=batch_size,
+        HAS_OUT_IS_NONE=b_has_out is None,
+        BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -128,12 +90,18 @@ def _fwd_kernel_gather(
     output,
     b_req_idx,
     b_mtp_index,
+    num_size,
+    BLOCK: tl.constexpr,
 ):
-    cur_index = tl.program_id(0)
-    cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_mtp_index = tl.load(b_mtp_index + cur_index)
-    cur_next_token_id = tl.load(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index)
-    tl.store(output + cur_index, cur_next_token_id)
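+    # Blocked gather: each in-range lane reads one request's next token id into the dense output.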
+    block_index = tl.program_id(0)
+    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
+    block_mask = block_range < num_size
+    cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
+    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
+    cur_next_token_id = tl.load(
+        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, mask=block_mask
+    )
+    tl.store(output + block_range, cur_next_token_id, mask=block_mask)
     return
 
 
@@ -148,72 +116,40 @@ def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b
         output: (batch_size,)
     """
     batch_size = b_req_idx.shape[0]
-    output = torch.empty_like(b_req_idx)
-    grid = (batch_size,)
+    output = torch.empty(batch_size, dtype=req_to_next_token_ids.dtype, device="cuda")
+    BLOCK = 256
+    grid = (triton.cdiv(batch_size, BLOCK),)
     num_warps = 1
     _fwd_kernel_gather[grid](
-        req_to_next_token_ids,
-        req_to_next_token_ids.stride(0),
-        req_to_next_token_ids.stride(1),
-        output,
-        b_req_idx,
-        b_mtp_index,
+        req_to_next_token_ids=req_to_next_token_ids,
+        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
+        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
+        output=output,
+        b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
+        num_size=batch_size,
+        BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
     return output
 
 
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-
-    probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0
-
-    return probs_sort, probs_idx
-
-
-def test_gather_and_scatter_token_to_cpu():
-    batch_size = 30
-    vocab_size = 60000
-    req_to_next_token_ids = torch.ones((1000,), dtype=torch.int32, pin_memory=True)
-    req_to_next_token_probs = torch.ones((1000,), dtype=torch.float32, pin_memory=True)
-    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
-    probs = torch.randn((batch_size, vocab_size)).cuda()
-    top_ps = torch.rand((batch_size,)).cuda()
-    top_ks = torch.ones((batch_size,), dtype=torch.int32).cuda()
-    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
-    sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
-    batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
-
-    gather_and_scatter_token_to_cpu(
-        probs_idx, probs_sort, req_to_next_token_ids, req_to_next_token_probs, sampled_index, req_ids
-    )
-    diff_ids = (req_to_next_token_ids[20 : 20 + batch_size].cuda() - batch_next_token_ids.view(-1)).abs().max()
-    diff_probs = (req_to_next_token_probs[20 : 20 + batch_size].cuda() - batch_next_token_probs.view(-1)).abs().max()
-    assert diff_ids < 1e-6
-    assert diff_probs < 1e-6
-    print("test_gather_and_scatter_token_to_cpu passed")
-
-
 def test_scatter_token_to_cpu():
     batch_size = 30
-    req_to_token_info = torch.zeros((1000,), dtype=torch.float32, pin_memory=True)
+    req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
     token_info = torch.randn((batch_size,)).cuda()
     req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
     mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
     scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
-    diff = (req_to_token_info[20 : 20 + batch_size].cuda() - token_info).abs().max()
+    diff = (req_to_token_info[20 : 20 + batch_size].cuda().view(-1) - token_info).abs().max()
     assert diff < 1e-6
     print("test_scatter_token_to_cpu passed")
 
 
 def test_gather_token():
     batch_size = 30
-    req_to_token_info = torch.zeros((1000,), dtype=torch.int32, pin_memory=True)
+    req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
     token_info = torch.randn((batch_size,)).cuda()
     req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
     mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
@@ -225,6 +161,5 @@ def test_gather_token():
 
 
 if __name__ == "__main__":
-    test_gather_and_scatter_token_to_cpu()
     test_scatter_token_to_cpu()
     test_gather_token()
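
A minimal usage sketch of the new b_has_out path, which the tests above do not
exercise. This is not part of the commit; the request indices and flag values are
illustrative, while shapes and dtypes follow the docstrings and tests:

    import torch
    batch_size = 4
    # Per-request table on pinned CPU memory, one MTP slot per request.
    req_to_next_token_ids = torch.zeros((1000, 1), dtype=torch.int32, pin_memory=True)
    next_token_ids = torch.arange(batch_size, dtype=torch.int32).cuda()
    b_req_idx = torch.tensor([3, 7, 11, 15], dtype=torch.int32).cuda()
    b_mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
    # Only requests flagged True receive a token this step; request 7 keeps its old value.
    b_has_out = torch.tensor([True, False, True, True]).cuda()
    scatter_token(next_token_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_has_out)
    assert req_to_next_token_ids[7, 0] == 0  # masked lane was not written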