@@ -72,37 +72,48 @@ def gather_and_scatter_token_to_cpu(
7272
@triton.jit
def _fwd_kernel_scatter(
    next_token_ids,
    req_to_next_token_ids,
    b_req_idx,
    b_mtp_index,
    req_to_next_token_ids_stride,
    req_to_next_token_ids_stride_1,
):
    """Scatter one next-token id per program into the request table.

    Grid: (batch_size,). Program i copies next_token_ids[i] into
    req_to_next_token_ids[b_req_idx[i], b_mtp_index[i]].
    """
    cur_index = tl.program_id(0)
    cur_req_idx = tl.load(b_req_idx + cur_index)
    cur_mtp_index = tl.load(b_mtp_index + cur_index)
    cur_next_token_id = tl.load(next_token_ids + cur_index)
    # BUG FIX: the column offset must be scaled by the column stride.
    # The original added cur_mtp_index unscaled (and left
    # req_to_next_token_ids_stride_1 unused), which is only correct when the
    # destination tensor is contiguous in its last dimension (stride_1 == 1).
    tl.store(
        req_to_next_token_ids
        + cur_req_idx * req_to_next_token_ids_stride
        + cur_mtp_index * req_to_next_token_ids_stride_1,
        cur_next_token_id,
    )
    return
8588
8689
@torch.no_grad()
def scatter_token(
    next_token_ids: torch.Tensor,
    req_to_next_token_ids: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_mtp_index: torch.Tensor,
):
    """Scatter per-request next token ids (GPU tensor) into the pinned
    req_to_next_token_ids table (CPU tensor), one element per request.

    Args:
        next_token_ids: (batch_size,) token id produced for each request.
        req_to_next_token_ids: (max_req_num, max_mtp_step) destination table.
        b_req_idx: (batch_size,) row index (request slot) for each token.
        b_mtp_index: (batch_size,) column index (MTP step) for each token.
    """
    assert next_token_ids.shape[0] == b_req_idx.shape[0]
    # b_mtp_index is consumed element-wise alongside b_req_idx in the kernel,
    # so it must be paired one-to-one with the requests.
    assert b_mtp_index.shape[0] == b_req_idx.shape[0]
    batch_size = b_req_idx.shape[0]
    grid = (batch_size,)
    num_warps = 1

    _fwd_kernel_scatter[grid](
        next_token_ids,
        req_to_next_token_ids,
        b_req_idx,
        b_mtp_index,
        req_to_next_token_ids.stride(0),
        req_to_next_token_ids.stride(1),
        num_warps=num_warps,
        num_stages=1,
    )
@@ -111,24 +122,28 @@ def scatter_token(token_info: torch.Tensor, req_to_token_info: torch.Tensor, b_r
111122
@triton.jit
def _fwd_kernel_gather(
    req_to_next_token_ids,
    req_to_next_token_ids_stride,
    req_to_next_token_ids_stride_1,
    output,
    b_req_idx,
    b_mtp_index,
):
    """Gather one next-token id per program from the request table.

    Grid: (batch_size,). Program i reads
    req_to_next_token_ids[b_req_idx[i], b_mtp_index[i]] into output[i].
    """
    cur_index = tl.program_id(0)
    cur_req_idx = tl.load(b_req_idx + cur_index)
    cur_mtp_index = tl.load(b_mtp_index + cur_index)
    # BUG FIX: scale the column offset by the column stride. The original
    # added cur_mtp_index unscaled (leaving req_to_next_token_ids_stride_1
    # unused), which is only correct for tensors contiguous in the last
    # dimension (stride_1 == 1).
    cur_next_token_id = tl.load(
        req_to_next_token_ids
        + cur_req_idx * req_to_next_token_ids_stride
        + cur_mtp_index * req_to_next_token_ids_stride_1
    )
    tl.store(output + cur_index, cur_next_token_id)
    return
124138
125139
126- def gather_token (req_to_token_info : torch .Tensor , b_req_idx : torch .Tensor ):
140+ def gather_token (req_to_next_token_ids : torch .Tensor , b_req_idx : torch . Tensor , b_mtp_index : torch .Tensor ):
127141 """
128142 This function is used to gather the token_info(CPU tensor) to the token_info(GPU tensor).
129143 Args:
130144 req_to_token_info: (max_req_num, max_mtp_step)
131145 b_req_idx: (batch_size,)
146+ b_mtp_index: (batch_size,)
132147 Returns:
133148 output: (batch_size,)
134149 """
@@ -137,10 +152,12 @@ def gather_token(req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor):
137152 grid = (batch_size ,)
138153 num_warps = 1
139154 _fwd_kernel_gather [grid ](
140- req_to_token_info ,
141- req_to_token_info .stride (0 ),
155+ req_to_next_token_ids ,
156+ req_to_next_token_ids .stride (0 ),
157+ req_to_next_token_ids .stride (1 ),
142158 output ,
143159 b_req_idx ,
160+ b_mtp_index ,
144161 num_warps = num_warps ,
145162 num_stages = 1 ,
146163 )
@@ -187,7 +204,8 @@ def test_scatter_token_to_cpu():
187204 req_to_token_info = torch .zeros ((1000 ,), dtype = torch .float32 , pin_memory = True )
188205 token_info = torch .randn ((batch_size ,)).cuda ()
189206 req_ids = torch .arange (20 , 20 + batch_size , dtype = torch .int32 ).cuda ()
190- scatter_token (token_info , req_to_token_info , req_ids )
207+ mtp_index = torch .zeros ((batch_size ,), dtype = torch .int32 ).cuda ()
208+ scatter_token (token_info , req_to_token_info , req_ids , mtp_index )
191209 diff = (req_to_token_info [20 : 20 + batch_size ].cuda () - token_info ).abs ().max ()
192210 assert diff < 1e-6
193211 print ("test_scatter_token_to_cpu passed" )
@@ -198,8 +216,9 @@ def test_gather_token():
198216 req_to_token_info = torch .zeros ((1000 ,), dtype = torch .int32 , pin_memory = True )
199217 token_info = torch .randn ((batch_size ,)).cuda ()
200218 req_ids = torch .arange (20 , 20 + batch_size , dtype = torch .int32 ).cuda ()
201- scatter_token (token_info , req_to_token_info , req_ids )
202- output = gather_token (req_to_token_info , req_ids )
219+ mtp_index = torch .zeros ((batch_size ,), dtype = torch .int32 ).cuda ()
220+ scatter_token (token_info , req_to_token_info , req_ids , mtp_index )
221+ output = gather_token (req_to_token_info , req_ids , mtp_index )
203222 diff = (token_info - output ).abs ().max ()
204223 assert diff < 1e-6
205224 print ("test_gather_token passed" )
0 commit comments