@@ -122,7 +122,8 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
122122
123123} // namespace
124124
125- std::tuple<TensorPtr, std::vector<runtime::decoder_batch::Request>, std::vector<runtime::SamplingConfig>>
125+ std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
126+ std::vector<executor::LookaheadDecodingConfig>>
126127CreateNewDecoderRequests::operator ()(runtime::ModelConfig const & modelConfig, runtime::WorldConfig const & worldConfig,
127128 executor::DecodingConfig const & decodingConfig, RequestVector const & contextRequests,
128129 runtime::BufferManager const & bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
@@ -139,9 +140,9 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
139140 copySequenceLengths (finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths (), beamWidth,
140141 bufferManager, runtimeStream);
141142
142- auto decoderRequests = createDecoderRequests (finishedContextRequests, inputBuffers. inputsIds , decodingConfig ,
143- decoderState, bufferManager, logitsType, modelConfig, worldConfig, runtimeStream, decoderStream ,
144- maxSequenceLength, medusaBuffers);
143+ auto [lookaheadPrompt, lookaheadAlgoConfigs] = createDecoderRequests (finishedContextRequests,
144+ inputBuffers. inputsIds , decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig,
145+ runtimeStream, decoderStream, maxSequenceLength, medusaBuffers);
145146
146147 auto const batchSize = finishedContextRequests.size ();
147148
@@ -155,7 +156,8 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
155156 TensorPtr batchSlotsView = runtime::ITensor::slice (inputBuffers.setupBatchSlots , 0 , batchSize);
156157
157158 TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
158- return {std::move (batchSlotsView), std::move (decoderRequests), std::move (samplingConfigs)};
159+ return {std::move (batchSlotsView), std::move (samplingConfigs), std::move (lookaheadPrompt),
160+ std::move (lookaheadAlgoConfigs)};
159161}
160162
161163void CreateNewDecoderRequests::newRequest (SizeType32 batchSlot, runtime::decoder_batch::Request const & request,
@@ -555,8 +557,8 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
555557 TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
556558}
557559
558- [[nodiscard]] std::vector<runtime::decoder_batch::Request> CreateNewDecoderRequests::createDecoderRequests (
559- RequestVector const & finishedContextRequests, TensorPtr const & inputIds,
560+ std::tuple<std:: vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
561+ CreateNewDecoderRequests::createDecoderRequests ( RequestVector const & finishedContextRequests, TensorPtr const & inputIds,
560562 executor::DecodingConfig const & decodingConfig, runtime::decoder::DecoderState& decoderState,
561563 BufferManager const & bufferManager, nvinfer1::DataType logitsType, runtime::ModelConfig const & modelConfig,
562564 runtime::WorldConfig const & worldConfig, runtime::CudaStream const & runtimeStream,
@@ -574,6 +576,16 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
574576 std::vector<decoder_batch::Request> decoderRequests;
575577 decoderRequests.reserve (finishedContextRequests.size ());
576578
579+ std::vector<runtime::ITensor::SharedConstPtr> lookaheadPrompt;
580+ std::vector<executor::LookaheadDecodingConfig> lookaheadAlgoConfigs;
581+ if (modelConfig.getSpeculativeDecodingMode ().isLookaheadDecoding ())
582+ {
583+ TLLM_CHECK_WITH_INFO (
584+ decodingConfig.getLookaheadDecodingConfig ().has_value (), " Lookahead decoding config must be provided" );
585+ lookaheadPrompt.reserve (finishedContextRequests.size ());
586+ lookaheadAlgoConfigs.reserve (finishedContextRequests.size ());
587+ }
588+
577589 SizeType32 inputOffset{0 };
578590 for (auto const & llmReq : finishedContextRequests)
579591 {
@@ -620,14 +632,11 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
620632 }
621633 else if (modelConfig.getSpeculativeDecodingMode ().isLookaheadDecoding ())
622634 {
623- decoderRequest.lookaheadRuntimeConfig = llmReq->getLookaheadConfig ()
624- ? llmReq->getLookaheadConfig ()
625- : decodingConfig.getLookaheadDecodingConfig ();
626- }
627- else if (modelConfig.getSpeculativeDecodingMode ().isExplicitDraftTokens ())
628- {
629- // Only Explicit draft tokens model needs dtype to WAR the lack of bf16 decoder.
630- decoderRequest.dtype = modelConfig.getDataType ();
635+ lookaheadPrompt.emplace_back (ITensor::slice (decoderRequest.ids , 0 , decoderRequest.inputLen ));
636+
637+ auto const & lookaheadRuntimeConfig
638+ = llmReq->getLookaheadConfig ().value_or (decodingConfig.getLookaheadDecodingConfig ().value ());
639+ lookaheadAlgoConfigs.emplace_back (lookaheadRuntimeConfig);
631640 }
632641 else if (modelConfig.getSpeculativeDecodingMode ().isEagle ())
633642 {
@@ -659,7 +668,7 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
659668 inputOffset += promptLen;
660669 }
661670
662- return decoderRequests ;
671+ return { std::move (lookaheadPrompt), std::move (lookaheadAlgoConfigs)} ;
663672}
664673
665674std::shared_ptr<runtime::ITensor> CreateNewDecoderRequests::retrieveDraftLogits (ModelConfig const & modelConfig,
0 commit comments