Skip to content

Commit f3404b2

Browse files
Merge branch 'main' into user/nzmora/moe_unit_test_change
2 parents f4d3e80 + 34f845b commit f3404b2

File tree

1,619 files changed

+28282
-12444
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

1,619 files changed

+28282
-12444
lines changed

cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class DecoderInputBuffers
3838
public:
3939
using SizeType32 = runtime::SizeType32;
4040
using TensorPtr = runtime::ITensor::SharedPtr;
41+
using TensorConstPtr = runtime::ITensor::SharedConstPtr;
4142

4243
explicit DecoderInputBuffers(
4344
SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, runtime::BufferManager const& manager);
@@ -60,13 +61,22 @@ class DecoderInputBuffers
6061
//! Requests considered in decoder forward
6162
RequestVector decoderRequests;
6263

64+
//! Logits of decoder requests
65+
std::vector<TensorPtr> decoderLogits;
66+
67+
//! Maximum number of decoding steps of decoder requests.
68+
//! This is only more than 1 for external draft tokens speculative decoding.
69+
SizeType32 maxDecoderSteps{1};
70+
6371
//! Batch slots for all decoder steps, [maxDecoderSteps][maxBatchSize]
6472
std::vector<TensorPtr> forwardBatchSlots;
6573

66-
//! Logits of decoder requests
67-
std::vector<TensorPtr> logits;
74+
//! Logits for requests in forwardBatchSlots (in the same order).
75+
//! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
76+
std::vector<std::vector<TensorConstPtr>> batchLogits;
6877

69-
//! Logits for speculative decoding (Medusa)
78+
//! Logits for speculative decoding (Medusa).
79+
//! The vector is sparse, only slots in forwardBatchSlots are used.
7080
//! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
7181
std::vector<std::vector<runtime::ITensor::SharedPtr>> predictedDraftLogits;
7282
};

cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,17 @@ class MakeDecodingBatchInputOutput : Algorithm
4040
constexpr static auto name{"MakeDecodingBatchInputOutput"};
4141

4242
using SizeType32 = tensorrt_llm::runtime::SizeType32;
43-
using TensorPtr = runtime::decoder_batch::Input::TensorPtr;
43+
using TensorPtr = runtime::ITensor::SharedPtr;
4444
template <typename T>
4545
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
4646

4747
MakeDecodingBatchInputOutput() = default;
4848

49-
std::unique_ptr<runtime::decoder_batch::Input> operator()(DecoderInputBuffers& inputBuffers,
50-
runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig,
51-
SizeType32 maxNumSequences, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
49+
void operator()(DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
50+
runtime::ModelConfig const& modelConfig, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
5251

53-
[[nodiscard]] static std::unique_ptr<runtime::decoder_batch::Input> createDecoderBatchInputs(
54-
std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
55-
std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots);
52+
static void createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, std::vector<SizeType32> const& activeSlots,
53+
runtime::decoder::DecoderState const& decoderState);
5654
};
5755

5856
} // namespace tensorrt_llm::batch_manager

cpp/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
2020
#include "tensorrt_llm/common/cudaDriverWrapper.h"
2121
#include "tensorrt_llm/common/cudaFp8Utils.h"
22+
#if ENABLE_FP4
23+
#include <cuda_fp4.h>
24+
#endif
2225
#include "tensorrt_llm/common/logger.h"
2326
#include "tensorrt_llm/common/tllmException.h"
2427
#include <algorithm>
@@ -545,6 +548,9 @@ template void printArrayInfo(__nv_bfloat16 const* ptr, uint64_t nElement, std::s
545548
#ifdef ENABLE_FP8
546549
template void printArrayInfo(__nv_fp8_e4m3 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
547550
#endif
551+
#ifdef ENABLE_FP4
552+
template void printArrayInfo(__nv_fp4_e2m1 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
553+
#endif
548554
template void printArrayInfo(uint32_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
549555
template void printArrayInfo(uint64_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
550556
template void printArrayInfo(int const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);

cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ class GptDecoderBatched : public IGptDecoderBatched
5252

5353
void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;
5454

55-
CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
56-
void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
55+
CudaEvent forwardAsync(
56+
decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) override;
57+
void forward(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) override;
5758

5859
//! @brief Gather final beam search results for request `batchSlot`.
5960
//! Result will only be available after event returned.
@@ -77,7 +78,7 @@ class GptDecoderBatched : public IGptDecoderBatched
7778

7879
private:
7980
//! @brief Calls decoders for tokens per engine step
80-
void forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);
81+
void forwardDispatch(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input);
8182

8283
private:
8384
CudaStreamPtr mRuntimeStream;

cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727

2828
namespace tensorrt_llm::batch_manager
2929
{
30+
class DecoderInputBuffers;
3031
class LlmRequest;
31-
}
32+
} // namespace tensorrt_llm::batch_manager
3233

3334
namespace tensorrt_llm::runtime
3435
{
@@ -39,43 +40,6 @@ namespace decoder
3940
class DecoderState;
4041
}
4142

42-
namespace decoder_batch
43-
{
44-
45-
class Input
46-
{
47-
public:
48-
using TensorConstPtr = ITensor::SharedConstPtr;
49-
using TensorPtr = ITensor::SharedPtr;
50-
51-
explicit Input(std::vector<std::vector<TensorConstPtr>> const& logits, SizeType32 maxDecoderSteps)
52-
: logits{logits}
53-
, maxDecoderSteps{maxDecoderSteps}
54-
{
55-
TLLM_CHECK_WITH_INFO(
56-
logits.size() == static_cast<size_t>(maxDecoderSteps), "logits vector size does not match maxDecoderSteps");
57-
}
58-
59-
explicit Input(std::vector<TensorConstPtr> const& logits)
60-
: Input{{logits}, 1}
61-
{
62-
}
63-
64-
//! Mandatory parameters
65-
//! Logits
66-
// FIXME: remove first dimension of tensors
67-
//! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
68-
std::vector<std::vector<TensorConstPtr>> logits;
69-
70-
//! Maximum number of decoding tokens of active slots
71-
SizeType32 maxDecoderSteps;
72-
73-
//! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
74-
std::vector<TensorPtr> batchSlots;
75-
};
76-
77-
} // namespace decoder_batch
78-
7943
//! GPT decoder class with support for in-flight batching
8044
class IGptDecoderBatched
8145
{
@@ -94,10 +58,13 @@ class IGptDecoderBatched
9458
virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0;
9559

9660
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
97-
virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;
61+
virtual CudaEvent forwardAsync(
62+
decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input)
63+
= 0;
9864

9965
//! @brief Run one step for all requests and wait for completion on the host.
100-
virtual void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;
66+
virtual void forward(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input)
67+
= 0;
10168

10269
//! @brief Gather final beam search results for request `batchIdx`.
10370
//! Result will only be available after event returned

cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ void GuidedDecoder::execute(DecoderInputBuffers const& decoderInputBuffers, Buff
182182
{
183183
auto const seqSlot = llmReq->mSeqSlot.value();
184184

185-
auto const& logits = decoderInputBuffers.logits.at(requestIdx);
185+
auto const& logits = decoderInputBuffers.decoderLogits.at(requestIdx);
186186
auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot});
187187

188188
// Use void* to unify the code for different mLogitsDtype

cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re
7979
auto& decoderRequests = inputBuffers.decoderRequests;
8080
decoderRequests.clear();
8181
decoderRequests.reserve(contextRequests.size());
82-
auto& allDecoderLogits = inputBuffers.logits;
82+
auto& allDecoderLogits = inputBuffers.decoderLogits;
8383
allDecoderLogits.clear();
8484
allDecoderLogits.reserve(contextRequests.size());
8585

cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque
8585

8686
auto& decoderRequests = inputBuffers.decoderRequests;
8787
decoderRequests.reserve(decoderRequests.size() + generationRequests.size());
88-
auto& allDecoderLogits = inputBuffers.logits;
88+
auto& allDecoderLogits = inputBuffers.decoderLogits;
8989
allDecoderLogits.reserve(allDecoderLogits.size() + generationRequests.size());
9090

9191
for (auto const& llmReq : generationRequests)

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ void LlmRequest::createSerializedResult(
6969
/// Note that there is some dependency on the order of operations in this method. Modify with care!
7070
std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
7171
{
72-
if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
72+
auto const streamingInProgress = mIsStreaming
73+
&& (mState == LlmRequestState::kGENERATION_IN_PROGRESS || mState == LlmRequestState::kGENERATION_TO_COMPLETE);
74+
if (!(isFinished() || streamingInProgress))
7375
{
7476
return std::nullopt;
7577
}

cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool rep
4949
for (size_t batchIdx = 0; batchIdx < inputBuffers.decoderRequests.size(); ++batchIdx)
5050
{
5151
auto const& llmReq = inputBuffers.decoderRequests.at(batchIdx);
52-
auto& logits = inputBuffers.logits.at(batchIdx);
52+
auto& logits = inputBuffers.decoderLogits.at(batchIdx);
5353

5454
// Invoke non-batched processor or collect arguments for batched processor
5555
if (llmReq->mLogitsPostProcessor)

0 commit comments

Comments
 (0)