From 48999386fa7568804db8e376958f3fc285e4ef0b Mon Sep 17 00:00:00 2001
From: anujj
Date: Wed, 11 Jun 2025 20:07:41 +0530
Subject: [PATCH 1/4] Experimental: Add support for prefix-chunking

---
 src/models/decoder_only.cpp | 39 ++++++++++++++++++++++++++++++++-----
 1 file changed, 34 insertions(+), 5 deletions(-)

diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 9bf3ab871..c9c6522eb 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -28,13 +28,42 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu
 }
 
 DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
-  UpdateInputsOutputs(next_tokens, next_indices, total_length);
+  size_t num_tokens = next_tokens.size();
+  const size_t chunk_size = 15;
+
+  if (num_tokens > chunk_size) {
+    // Chunking logic for context phase - process in chunks of 512 tokens
+    size_t processed_tokens = 0;
+    int length = total_length - static_cast<int>(num_tokens);
+    while (processed_tokens < num_tokens) {
+      size_t current_chunk_size = std::min(chunk_size, num_tokens - processed_tokens);
+
+      // Create subspans for current chunk
+      auto chunk_tokens = next_tokens.subspan(processed_tokens, current_chunk_size);
+      //auto chunk_indices = next_indices.subspan(processed_tokens, current_chunk_size);
+      length = length + static_cast<int>(current_chunk_size);
+      // Process this chunk - fills KV cache progressively
+      UpdateInputsOutputs(chunk_tokens, next_indices, length);
+
+      // Graph capture is typically disabled during context phase chunking
+      bool graph_capture_this_run = false;  // Disable graph capture during chunking
+      State::Run(*model_.session_decoder_, graph_capture_this_run);
+
+      processed_tokens += current_chunk_size;
+    }
+
+    // Return logits from the last chunk for potential sampling
+    return logits_.Get();
+  } else {
+    // Original logic for tokens <= 512 (generation phase or small context)
+    UpdateInputsOutputs(next_tokens, next_indices, total_length);
 
-  // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
-  bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
-  State::Run(*model_.session_decoder_, graph_capture_this_run);
+    // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
+    bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
+    State::Run(*model_.session_decoder_, graph_capture_this_run);
 
-  return logits_.Get();
+    return logits_.Get();
+  }
 }
 
 void DecoderOnly_State::RewindTo(size_t index) {

From a71397353576a5dd61fae29b6e6b0456b554aac6 Mon Sep 17 00:00:00 2001
From: anujj
Date: Wed, 11 Jun 2025 20:09:45 +0530
Subject: [PATCH 2/4] Update decoder_only.cpp

---
 src/models/decoder_only.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index c9c6522eb..28b7ab386 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -29,7 +29,7 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu
 
 DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
   size_t num_tokens = next_tokens.size();
-  const size_t chunk_size = 15;
+  const size_t chunk_size = 1024;  // Experimental value
 
   if (num_tokens > chunk_size) {
     // Chunking logic for context phase - process in chunks of 512 tokens
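The loop added in PATCH 1 splits a long prompt into fixed-size chunks and runs the decoder once per chunk, so the KV cache is filled incrementally and only the final chunk's logits are used to sample the first generated token. A minimal Python sketch of that control flow, assuming a hypothetical run_chunk callback that stands in for UpdateInputsOutputs followed by State::Run (it is not part of the onnxruntime-genai API):

def chunked_prefill(tokens, total_length, chunk_size, run_chunk):
    # Mirrors the while-loop from PATCH 1: feed `tokens` to the model in
    # chunks of at most `chunk_size`, growing the processed length (and the
    # KV cache) after each chunk.
    processed = 0
    length = total_length - len(tokens)
    logits = None
    while processed < len(tokens):
        current = min(chunk_size, len(tokens) - processed)
        chunk = tokens[processed:processed + current]
        length += current
        logits = run_chunk(chunk, length)  # fills the KV cache progressively
        processed += current
    return logits  # only the last chunk's logits are needed for sampling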
From 7227a263e4cac3433df82b1762baf4422a49147b Mon Sep 17 00:00:00 2001
From: Anuj Jalota
Date: Mon, 15 Sep 2025 20:28:16 +0530
Subject: [PATCH 3/4] Enable prefix chunking through the config overlay

---
 examples/python/model-generate.py | 4 +++-
 src/config.cpp                    | 2 ++
 src/config.h                      | 1 +
 src/models/decoder_only.cpp       | 6 +++---
 4 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py
index f9ccc4b0d..72d060954 100644
--- a/examples/python/model-generate.py
+++ b/examples/python/model-generate.py
@@ -19,7 +19,8 @@ def main(args):
     batch_size = len(prompts)
 
     config = og.Config(args.model_path)
-    config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}}}}}')
+    # Example: Configure search parameters including chunk_size for prefix chunking
+    config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}, "chunk_size": {args.chunk_size}}}}}')
     if args.execution_provider != "follow_config":
         config.clear_providers()
 
@@ -90,6 +91,7 @@ def main(args):
     parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
    parser.add_argument('-b', '--batch_size_for_cuda_graph', type=int, default=1, help='Max batch size for CUDA graph')
     parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.')
+    parser.add_argument('--chunk_size', type=int, default=-1, help='Chunk size for prefix chunking during context processing (default: -1 = disabled, >0 = enabled)')
     parser.add_argument('--non-interactive', action=argparse.BooleanOptionalAction, required=False, default=False, help='Non-interactive mode, mainly for CI usage')
     args = parser.parse_args()
 
diff --git a/src/config.cpp b/src/config.cpp
index e3e115b8e..bde7e68f4 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -868,6 +868,8 @@ struct Search_Element : JSON::Element {
       v_.length_penalty = static_cast<float>(JSON::Get<double>(value));
     } else if (name == "random_seed") {
       v_.random_seed = SafeDoubleToInt(JSON::Get<double>(value), name);
+    } else if (name == "chunk_size") {
+      v_.chunk_size = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "do_sample") {
       v_.do_sample = JSON::Get<bool>(value);
     } else if (name == "past_present_share_buffer") {
diff --git a/src/config.h b/src/config.h
index bc856e1c0..d0064379c 100644
--- a/src/config.h
+++ b/src/config.h
@@ -273,6 +273,7 @@ struct Config {
     float length_penalty{1.0f};        // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
     bool past_present_share_buffer{};  // The past/present kv tensors are shared and allocated once to max_length (cuda only)
     int random_seed{-1};               // -1 = Seed with random device, otherwise use value to seed RNG
+    int chunk_size{-1};                // Chunk size for prefix chunking during context processing. -1 = disabled, >0 = enabled with specified chunk size.
   } search;
 
   void AddMapping(const std::string& nominal_name, const std::string& graph_name);
diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 28b7ab386..6196605db 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -29,10 +29,10 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu
 
 DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
   size_t num_tokens = next_tokens.size();
-  const size_t chunk_size = 1024;  // Experimental value
+  const size_t chunk_size = static_cast<size_t>(model_.config_->search.chunk_size);
 
-  if (num_tokens > chunk_size) {
-    // Chunking logic for context phase - process in chunks of 512 tokens
+  if (chunk_size > 0 && num_tokens > chunk_size) {
+    // Chunking logic for context phase - process in chunks based on configured chunk_size
     size_t processed_tokens = 0;
     int length = total_length - static_cast<int>(num_tokens);
     while (processed_tokens < num_tokens) {

From c44c26696954db254979f0fed76dca9bc072eb58 Mon Sep 17 00:00:00 2001
From: Anuj Jalota
Date: Tue, 16 Sep 2025 22:01:00 +0530
Subject: [PATCH 4/4] Enable prefill chunking only for NvTensorRtRtx and CUDA

---
 src/models/decoder_only.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 6196605db..c9ac6edb0 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -31,7 +31,11 @@ DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>&
   size_t num_tokens = next_tokens.size();
   const size_t chunk_size = static_cast<size_t>(model_.config_->search.chunk_size);
 
-  if (chunk_size > 0 && num_tokens > chunk_size) {
+  // Enable prefill chunking for CUDA and NvTensorRtRtx devices
+  bool is_chunking_supported_device = (model_.p_device_->GetType() == DeviceType::CUDA ||
+                                       model_.p_device_->GetType() == DeviceType::NvTensorRtRtx);
+
+  if (is_chunking_supported_device && chunk_size > 0 && num_tokens > chunk_size) {
     // Chunking logic for context phase - process in chunks based on configured chunk_size
     size_t processed_tokens = 0;
     int length = total_length - static_cast<int>(num_tokens);
@@ -55,7 +59,8 @@ DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>&
     // Return logits from the last chunk for potential sampling
     return logits_.Get();
   } else {
-    // Original logic for tokens <= 512 (generation phase or small context)
+    // Original logic for tokens <= chunk_size (generation phase or small context)
+    // or chunking disabled due to unsupported device
     UpdateInputsOutputs(next_tokens, next_indices, total_length);
 
     // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
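With PATCH 3 and PATCH 4 applied, prefix chunking is controlled entirely by the search.chunk_size configuration value and only takes effect on CUDA and NvTensorRtRtx devices; everywhere else the original single-pass prefill is used. A short usage sketch in Python, following the API calls already used in examples/python/model-generate.py above (the model path and the 1024 chunk size are placeholder values):

import onnxruntime_genai as og

config = og.Config("path/to/model")  # placeholder model directory
# chunk_size > 0 enables chunked prefill for the context phase; -1 (the default) disables it.
config.overlay('{"search": {"chunk_size": 1024}}')

model = og.Model(config)
tokenizer = og.Tokenizer(model)
params = og.GeneratorParams(model)
generator = og.Generator(model, params)

# A prompt longer than chunk_size is prefilled in 1024-token chunks on supported devices.
generator.append_tokens(tokenizer.encode("Some very long prompt ..."))
while not generator.is_done():
    generator.generate_next_token()
print(tokenizer.decode(generator.get_sequence(0)))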