2727
2828namespace tensorrt_llm ::batch_manager
2929{
30+ class DecoderInputBuffers ; // forward declaration; referenced below as `batch_manager::DecoderInputBuffers const &` by IGptDecoderBatched::forwardAsync / forward
3031class LlmRequest ;
31- }
32+ } // namespace tensorrt_llm::batch_manager
3233
3334namespace tensorrt_llm ::runtime
3435{
@@ -39,43 +40,6 @@ namespace decoder
3940class DecoderState ;
4041}
4142
42- namespace decoder_batch
43- {
44-
45- class Input // Decoder input bundle: holds the logits, maxDecoderSteps and batchSlots members declared below.
46- {
47- public:
48- using TensorConstPtr = ITensor::SharedConstPtr;
49- using TensorPtr = ITensor::SharedPtr;
50-
51- explicit Input (std::vector<std::vector<TensorConstPtr>> const & logits, SizeType32 maxDecoderSteps) // copies both arguments into the members of the same name
52- : logits{logits}
53- , maxDecoderSteps{maxDecoderSteps}
54- {
55- TLLM_CHECK_WITH_INFO ( // requires exactly one outer logits entry per decoder step; NOTE(review): failure behaviour is defined by the macro, not visible here
56- logits.size () == static_cast <size_t >(maxDecoderSteps), " logits vector size does not match maxDecoderSteps" );
57- }
58-
59- explicit Input (std::vector<TensorConstPtr> const & logits) // single-step overload: wraps logits in a one-element outer vector and delegates with maxDecoderSteps = 1
60- : Input{{logits}, 1 }
61- {
62- }
63-
64- //! Mandatory parameters
65- //! Logits
66- // FIXME: remove first dimension of tensors
67- //! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
68- std::vector<std::vector<TensorConstPtr>> logits;
69-
70- //! Maximum number of decoding tokens of active slots
71- SizeType32 maxDecoderSteps;
72-
73- //! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
74- std::vector<TensorPtr> batchSlots; // not initialized by either constructor, so it starts out empty
75- };
76-
77- } // namespace decoder_batch
78-
7943//! GPT decoder class with support for in-flight batching
8044class IGptDecoderBatched
8145{
@@ -94,10 +58,13 @@ class IGptDecoderBatched
9458 virtual void disableLookahead (RequestVector const & genRequests, TensorPtr const & batchSlots) = 0; // pure virtual; NOTE(review): its doc comment sits above this hunk and is not visible here
9559
9660 //! @brief Run one step for all requests without blocking the host process and return the token (a CudaEvent) for synchronization.
97- virtual CudaEvent forwardAsync (decoder::DecoderState const & decoderState, decoder_batch::Input const & input) = 0;
61+ virtual CudaEvent forwardAsync ( // the input parameter type changes from decoder_batch::Input to batch_manager::DecoderInputBuffers
62+ decoder::DecoderState const & decoderState, batch_manager::DecoderInputBuffers const & input)
63+ = 0;
9864
9965 //! @brief Run one step for all requests and wait for completion on the host.
100- virtual void forward (decoder::DecoderState const & decoderState, decoder_batch::Input const & input) = 0;
66+ virtual void forward (decoder::DecoderState const & decoderState, batch_manager::DecoderInputBuffers const & input) // the input parameter type changes from decoder_batch::Input to batch_manager::DecoderInputBuffers
67+ = 0;
10168
10269 // ! @brief Gather final beam search results for request `batchIdx`.
10370 // ! Result will only be available after event returned
0 commit comments