diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py
index f9ccc4b0d..72d060954 100644
--- a/examples/python/model-generate.py
+++ b/examples/python/model-generate.py
@@ -19,7 +19,8 @@ def main(args):
     batch_size = len(prompts)
 
     config = og.Config(args.model_path)
-    config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}}}}}')
+    # Example: Configure search parameters, including chunk_size for prefix chunking
+    config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}, "chunk_size": {args.chunk_size}}}}}')
     if args.execution_provider != "follow_config":
         config.clear_providers()
@@ -90,6 +91,7 @@ def main(args):
     parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
     parser.add_argument('-b', '--batch_size_for_cuda_graph', type=int, default=1, help='Max batch size for CUDA graph')
     parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.')
+    parser.add_argument('--chunk_size', type=int, default=-1, help='Chunk size for prefix chunking during context processing (default: -1 = disabled, >0 = enabled)')
     parser.add_argument('--non-interactive', action=argparse.BooleanOptionalAction, required=False, default=False, help='Non-interactive mode, mainly for CI usage')
 
     args = parser.parse_args()
diff --git a/src/config.cpp b/src/config.cpp
index e3e115b8e..bde7e68f4 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -868,6 +868,8 @@ struct Search_Element : JSON::Element {
       v_.length_penalty = static_cast<float>(JSON::Get<double>(value));
     } else if (name == "random_seed") {
       v_.random_seed = SafeDoubleToInt(JSON::Get<double>(value), name);
+    } else if (name == "chunk_size") {
+      v_.chunk_size = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "do_sample") {
       v_.do_sample = JSON::Get<bool>(value);
     } else if (name == "past_present_share_buffer") {
diff --git a/src/config.h b/src/config.h
index bc856e1c0..d0064379c 100644
--- a/src/config.h
+++ b/src/config.h
@@ -273,6 +273,7 @@ struct Config {
     float length_penalty{1.0f};        // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
     bool past_present_share_buffer{};  // The past/present kv tensors are shared and allocated once to max_length (cuda only)
     int random_seed{-1};               // -1 = Seed with random device, otherwise use value to seed RNG
+    int chunk_size{-1};                // Chunk size for prefix chunking during context processing. -1 = disabled, >0 = enabled with the specified chunk size.
  } search;
 
   void AddMapping(const std::string& nominal_name, const std::string& graph_name);
diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 9bf3ab871..c9ac6edb0 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -28,13 +28,47 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inputs) {
 }
 
 DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
-  UpdateInputsOutputs(next_tokens, next_indices, total_length);
+  size_t num_tokens = next_tokens.size();
+  const size_t chunk_size = static_cast<size_t>(model_.config_->search.chunk_size);
+
+  // Enable prefill chunking for CUDA and NvTensorRtRtx devices
+  bool is_chunking_supported_device = (model_.p_device_->GetType() == DeviceType::CUDA ||
+                                       model_.p_device_->GetType() == DeviceType::NvTensorRtRtx);
 
-  // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
-  bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
-  State::Run(*model_.session_decoder_, graph_capture_this_run);
+  if (is_chunking_supported_device && chunk_size > 0 && num_tokens > chunk_size) {
+    // Chunking logic for the context phase - process the prompt in chunks of the configured chunk_size
+    size_t processed_tokens = 0;
+    int length = total_length - static_cast<int>(num_tokens);
+    while (processed_tokens < num_tokens) {
+      size_t current_chunk_size = std::min(chunk_size, num_tokens - processed_tokens);
+
+      // Create a subspan for the current chunk
+      auto chunk_tokens = next_tokens.subspan(processed_tokens, current_chunk_size);
+      // auto chunk_indices = next_indices.subspan(processed_tokens, current_chunk_size);
+      length += static_cast<int>(current_chunk_size);
+      // Process this chunk - fills the KV cache progressively
+      UpdateInputsOutputs(chunk_tokens, next_indices, length);
+
+      // Graph capture is disabled during context-phase chunking
+      bool graph_capture_this_run = false;
+      State::Run(*model_.session_decoder_, graph_capture_this_run);
+
+      processed_tokens += current_chunk_size;
+    }
+
+    // Return logits from the last chunk for sampling
+    return logits_.Get();
+  } else {
+    // Original logic for num_tokens <= chunk_size (generation phase or small context),
+    // or chunking disabled / unsupported device
+    UpdateInputsOutputs(next_tokens, next_indices, total_length);
 
-  return logits_.Get();
+    // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
+    bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
+    State::Run(*model_.session_decoder_, graph_capture_this_run);
+
+    return logits_.Get();
+  }
 }
 
 void DecoderOnly_State::RewindTo(size_t index) {
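
To make the subspan arithmetic in the new loop easier to follow, here is a small standalone Python sketch (illustration only, not part of the change; the helper name chunk_spans is invented) that mirrors how processed_tokens, current_chunk_size, and the running length advance in DecoderOnly_State::Run.

def chunk_spans(total_length: int, num_tokens: int, chunk_size: int):
    """Yield (offset, size, length) for each prefill chunk, where length is the
    KV-cache length after that chunk has been processed."""
    processed = 0
    length = total_length - num_tokens  # tokens already in the KV cache
    while processed < num_tokens:
        current = min(chunk_size, num_tokens - processed)
        length += current
        yield processed, current, length
        processed += current

# A fresh 1000-token prompt (total_length == num_tokens) with chunk_size 256
# runs the decoder session four times: offsets 0/256/512/768 with sizes
# 256/256/256/232, and the cache length grows to 256, 512, 768, 1000.
print(list(chunk_spans(1000, 1000, 256)))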
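
For reference, a minimal sketch of how the new chunk_size search option could be exercised end to end through the Python API, following the same og.Config / config.overlay flow that model-generate.py already uses. The model directory, prompt, chunk size, and max_length below are placeholders, and the tokenizer/generator loop is the library's usual pattern rather than anything introduced by this change.

import onnxruntime_genai as og

# Placeholder model directory; any decoder-only ONNX GenAI model works here.
model_path = "./models/example-decoder-only"

config = og.Config(model_path)
# chunk_size > 0 enables prefix chunking during the context (prefill) phase;
# the default of -1 added in config.h keeps the original single-pass prefill.
config.overlay('{"search": {"chunk_size": 256}}')

model = og.Model(config)
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=2048)

generator = og.Generator(model, params)
# A prompt longer than chunk_size is what reaches the new path in
# DecoderOnly_State::Run: its prefill is split into 256-token chunks.
prompt = "Summarize the following log line: connection reset by peer. " * 100
generator.append_tokens(tokenizer.encode(prompt))
while not generator.is_done():
    generator.generate_next_token()

print(tokenizer.decode(generator.get_sequence(0)))

The same behavior is reachable from the command line through the flag added to model-generate.py, e.g. passing --chunk_size 256 alongside the usual model-path and execution-provider arguments. Note that the chunked path forces graph capture off for the prefill runs; as the existing comment in Run notes, graph capture is aimed at the repeated single-token generation steps.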