@@ -63,27 +63,24 @@ __device__ void copyChunkedHiddenStates(T const* srcPtr, T* dstPtr, int const nu
6363}
6464
6565template <typename T>
66- __global__ void mtpPrepareDrafterInputsKernel (int const numMTPModules, int const curMTPLayerIdx, int const batchSize ,
67- int const numContextRequest , int const hiddenSize , int const * inputIds, int const * seqLens ,
68- T ** const mtpPastHiddenStatesPtrs, int ** const mtpPastTokensPtrs, T* const previousLayerHiddenStates ,
69- int * const previousLayerDraftTokens, int * returnInputIds, T* returnHiddenStates)
66+ __global__ void mtpPrepareDrafterInputsKernel (int const numMTPModules, int const numContextRequest ,
67+ int const hiddenSize , int const * inputIds , int const * seqLens, T** const mtpPastHiddenStatesPtrs ,
68+ int ** const mtpPastTokensPtrs, T* const hiddenStates, int const * acceptedTokens, int const * numAcceptedTokens ,
69+ int * returnInputIds, T* returnHiddenStates)
7070{
7171 /*
7272 In a batch of requests: context requests (at the beginning) + generation requests
7373 numGenerationRequest = batchSize - numContextRequest
7474
7575 inputIds: [N]
76- - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
77- + 1)
78- - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
76+ - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1)
7977 seqLens: [batchSize]
8078 mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
8179 mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
82- previousLayerHiddenStates: [N, hiddenSize]
83- - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
84- + 1) (from target model)
85- - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
86- previousLayerDraftTokens: [batchSize], the draft tokens generated by the previous layer
80+ hiddenStates: [N, hiddenSize]
81+ - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1) (from target model)
82+ acceptedTokens: [batchSize, numMTPModules + 1]
83+ numAcceptedTokens: [batchSize]
8784 returnInputIds: [N]
8885 - N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
8986 returnHiddenStates: [N, hiddenSize]
@@ -94,6 +91,7 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
9491
9592 T const * curMTPPastHiddenStatesPtr = mtpPastHiddenStatesPtrs[bid];
9693 int const * curMTPPastTokensPtr = mtpPastTokensPtrs[bid];
94+ int const * curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1 );
9795
9896 int curSeqLen = seqLens[bid];
9997
@@ -117,63 +115,44 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
117115 }
118116
119117 int const * curInputIdsPtr = inputIds + inputIdsStartOffset;
120- T const * curPreviousLayerHiddenStates = previousLayerHiddenStates + inputIdsStartOffset * hiddenSize;
118+ T const * curHiddenStates = hiddenStates + inputIdsStartOffset * hiddenSize;
121119
122120 int * curReturnInputIdsPtr = returnInputIds + returnInputIdsStartOffset;
123121 T* curReturnHiddenStatesIdsPtr = returnHiddenStates + returnInputIdsStartOffset * hiddenSize;
124122
125123 // // main logic
126-
127- if (curMTPLayerIdx == 0 )
124+ if (bid < numContextRequest)
128125 {
129- if (bid < numContextRequest)
126+ // context requests
127+ if (tid == 0 )
130128 {
131- // context requests
132- if (tid == 0 )
129+ // 1) For the new inputIds
130+ for ( int ii = 0 ; ii < curSeqLen - 1 ; ii++ )
133131 {
134- // 1) For the new inputIds
135- for (int ii = 0 ; ii < curSeqLen - 1 ; ii++)
136- {
137- curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1 ]; // +1 because of offset 1, prompt[1:]
138- }
139- // Append the latest golden token, i.e., the last one in the past tokens list
140- curReturnInputIdsPtr[curSeqLen - 1 ] = curMTPPastTokensPtr[numMTPModules - 1 ];
132+ curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1 ]; // +1 because of offset 1, prompt[1:]
141133 }
142-
143- // 2) For the new past hidden states
144- copyChunkedHiddenStates (curPreviousLayerHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
134+ // Append the latest golden token, i.e., the first one in the accepted tokens list
135+ curReturnInputIdsPtr[curSeqLen - 1 ] = curAcceptedTokensPtr[0 ];
145136 }
146- else
147- {
148- // generation requests
149- if (tid == 0 )
150- {
151- // 1) For the new inputIds
152- for (int ii = 0 ; ii < numMTPModules; ii++)
153- {
154- curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii];
155- }
156- }
157137
158- // 2) For the new past hidden states
159- copyChunkedHiddenStates (curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
160- }
138+ // 2) For the new past hidden states
139+ copyChunkedHiddenStates (curHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
161140 }
162- else // For curMTPLayerIdx > 0
141+ else
163142 {
143+ // generation requests
164144 if (tid == 0 )
165145 {
166146 // 1) For the new inputIds
167- int numPastTokens = (bid < numContextRequest) ? curSeqLen : numMTPModules;
168- for (int ii = 0 ; ii < numPastTokens; ii++)
147+ for (int ii = 0 ; ii < numMTPModules - 1 ; ii++)
169148 {
170- curReturnInputIdsPtr[ii] = curInputIdsPtr [ii + 1 ];
149+ curReturnInputIdsPtr[ii] = curMTPPastTokensPtr [ii + 1 ];
171150 }
172- curReturnInputIdsPtr[numPastTokens - 1 ] = previousLayerDraftTokens[ bid];
151+ curReturnInputIdsPtr[numMTPModules - 1 ] = curAcceptedTokensPtr[numAcceptedTokens[ bid] - 1 ];
173152 }
174153
175154 // 2) For the new past hidden states
176- // Directly use previous layer's output hidden states
155+ copyChunkedHiddenStates (curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
177156 }
178157}
179158
@@ -185,10 +164,10 @@ void invokeMTPPrepareDrafterInputs(MTPPrepareDrafterInputsParam& params, cudaStr
185164 params.hiddenSize * sizeof (T) % 16 == 0 ); // Which is because we will use float4 to copy the hidden states.
186165
187166 mtpPrepareDrafterInputsKernel<T><<<params.batchSize, BLOCK_SIZE, 0 , stream>>> (params.numMTPModules ,
188- params.curMTPLayerIdx , params.batchSize , params.numContextRequest , params.hiddenSize , params. inputIds ,
189- params. seqLens , reinterpret_cast <T**>(params.mtpPastHiddenStatesPtrs ), params.mtpPastTokensPtrs ,
190- reinterpret_cast <T*>(params.previousLayerHiddenStates ), params.previousLayerDraftTokens , params.returnInputIds ,
191- reinterpret_cast <T*>(params.returnHiddenStates ));
167+ params.numContextRequest , params.hiddenSize , params.inputIds , params.seqLens ,
168+ reinterpret_cast <T**>(params.mtpPastHiddenStatesPtrs ), params.mtpPastTokensPtrs ,
169+ reinterpret_cast <T*>(params.hiddenStates ), params.acceptedTokens , params.numAcceptedTokens ,
170+ params. returnInputIds , reinterpret_cast <T*>(params.returnHiddenStates ));
192171
193172 sync_check_cuda_error (stream);
194173}
@@ -362,7 +341,7 @@ template void invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>(
362341template <typename T>
363342__global__ void mtpUpdateHiddenStatesKernel (int const numMTPModules, int const batchSize, int const numContextRequest,
364343 int const hiddenSize, int const * inputIds, int const * seqLens, T* targetModelHiddenStates,
365- T** mtpPastHiddenStatesPtrs, int ** mtpPastTokensPtrs, int const * numAcceptedTokens, int const * acceptedTokens )
344+ T** mtpPastHiddenStatesPtrs, int ** mtpPastTokensPtrs, int const * numAcceptedTokens)
366345{
367346 /*
368347 In a batch of requests: context requests (at the beginning) + generation requests
@@ -374,7 +353,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
374353 mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
375354 mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
376355 numAcceptedTokens: [batchSize]
377- acceptedTokens: [batchSize][numMTPModules + 1], flatten
378356 */
379357
380358 int const bid = static_cast <int >(blockIdx .x ); // Each block is responsible for a request.
@@ -395,7 +373,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
395373
396374 auto curInputIdsPtr = inputIds + inputIdsStartOffset;
397375 auto curTargetModelHiddenStatesPtr = targetModelHiddenStates + inputIdsStartOffset * hiddenSize;
398- auto curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1 );
399376
400377 // Update MTP tokens
401378 // Just use one thread to execute this copy
@@ -405,12 +382,10 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
405382 {
406383 // Context request
407384 // Copy the end of prompt tokens
408- for (int ii = 0 ; ii < numMTPModules - 1 ; ii++)
385+ for (int ii = 0 ; ii < numMTPModules; ii++)
409386 {
410- curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + 1 + ii];
387+ curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + ii];
411388 }
412- // Copy the new generated golden token
413- curMTPPastTokensPtr[numMTPModules - 1 ] = curAcceptedTokensPtr[0 ];
414389 }
415390 else
416391 {
@@ -424,7 +399,7 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
424399 int acceptedTokenStartIdx = max (0 , curAcceptedLen - numMTPModules);
425400 for (; ii < numMTPModules; ii++, acceptedTokenStartIdx++)
426401 {
427- curMTPPastTokensPtr[ii] = curAcceptedTokensPtr [acceptedTokenStartIdx];
402+ curMTPPastTokensPtr[ii] = curInputIdsPtr [acceptedTokenStartIdx];
428403 }
429404 }
430405 }
@@ -463,7 +438,7 @@ void invokeMTPUpdateHiddenStates(MTPUpdateHiddenStatesParam& params, cudaStream_
463438 mtpUpdateHiddenStatesKernel<T><<<params.batchSize, BLOCK_SIZE, 0 , stream>>> (params.numMTPModules , params.batchSize ,
464439 params.numContextRequest , params.hiddenSize , params.inputIds , params.seqLens ,
465440 reinterpret_cast <T*>(params.targetModelHiddenStates ), reinterpret_cast <T**>(params.mtpPastHiddenStatesPtrs ),
466- params.mtpPastTokensPtrs , params.numAcceptedTokens , params. acceptedTokens );
441+ params.mtpPastTokensPtrs , params.numAcceptedTokens );
467442 sync_check_cuda_error (stream);
468443}
469444
0 commit comments