
Commit c44c266

Enable prefill chunking only for the NvTensorRtRtx and CUDA devices
1 parent 7227a26 commit c44c266

File tree

1 file changed: +7 -2 lines

src/models/decoder_only.cpp

Lines changed: 7 additions & 2 deletions
@@ -31,7 +31,11 @@ DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>&
   size_t num_tokens = next_tokens.size();
   const size_t chunk_size = static_cast<size_t>(model_.config_->search.chunk_size);
 
-  if (chunk_size > 0 && num_tokens > chunk_size) {
+  // Enable prefill chunking for CUDA and NvTensorRtRtx devices
+  bool is_chunking_supported_device = (model_.p_device_->GetType() == DeviceType::CUDA ||
+                                       model_.p_device_->GetType() == DeviceType::NvTensorRtRtx);
+
+  if (is_chunking_supported_device && chunk_size > 0 && num_tokens > chunk_size) {
     // Chunking logic for context phase - process in chunks based on configured chunk_size
     size_t processed_tokens = 0;
     int length = total_length - static_cast<int>(num_tokens);
@@ -55,7 +59,8 @@ DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>&
     // Return logits from the last chunk for potential sampling
     return logits_.Get();
   } else {
-    // Original logic for tokens <= 512 (generation phase or small context)
+    // Original logic for tokens <= chunk_size (generation phase or small context)
+    // or chunking disabled due to unsupported device
     UpdateInputsOutputs(next_tokens, next_indices, total_length);
 
     // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
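In effect, the chunked-prefill path is now gated behind a device-type check: only the CUDA and NvTensorRtRtx backends split long prompts into chunk_size slices, while every other device falls through to the original single-pass prefill, the same branch that handles generation and short contexts. For illustration, below is a minimal standalone sketch of that control flow. It is not the library's code: RunChunk, prompt, chunking_supported, and chunk_size here are hypothetical stand-ins for the real UpdateInputsOutputs/session-run path, the device check, and search.chunk_size.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <span>
#include <vector>

// Hypothetical stand-in for the model's per-chunk forward pass; in the
// real code this loosely corresponds to UpdateInputsOutputs plus a run.
void RunChunk(std::span<const int> tokens, int past_length) {
  std::cout << "prefill chunk: " << tokens.size()
            << " tokens at offset " << past_length << "\n";
}

int main() {
  std::vector<int> prompt(1300, 0);      // pretend prompt of 1300 tokens
  const std::size_t chunk_size = 512;    // analogous to search.chunk_size
  const bool chunking_supported = true;  // analogous to the device-type check

  std::size_t num_tokens = prompt.size();
  if (chunking_supported && chunk_size > 0 && num_tokens > chunk_size) {
    // Context (prefill) phase: feed the prompt in chunk_size slices,
    // advancing the past length so the KV cache stays consistent.
    std::size_t processed = 0;
    while (processed < num_tokens) {
      std::size_t n = std::min(chunk_size, num_tokens - processed);
      RunChunk(std::span<const int>(prompt).subspan(processed, n),
               static_cast<int>(processed));
      processed += n;
    }
  } else {
    // Generation phase, small context, or chunking unsupported on this
    // device: single pass over all tokens, as before this commit.
    RunChunk(prompt, 0);
  }
}

Note that the else branch now also covers prompts longer than chunk_size on unsupported devices, which is exactly what the updated comment in the diff calls out.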
