Skip to content

Commit e72768d

Browse files
committed
Enable prefix chunking though the config overlay
1 parent 17d7549 commit e72768d

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

examples/python/model-generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def main(args):
1919
batch_size = len(prompts)
2020

2121
config = og.Config(args.model_path)
22-
config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}}}}}')
22+
# Example: Configure search parameters including chunk_size for prefix chunking
23+
config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}, "chunk_size": {args.chunk_size}}}}}')
2324

2425
if args.execution_provider != "follow_config":
2526
config.clear_providers()
@@ -90,6 +91,7 @@ def main(args):
9091
parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
9192
parser.add_argument('-b', '--batch_size_for_cuda_graph', type=int, default=1, help='Max batch size for CUDA graph')
9293
parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.')
94+
parser.add_argument('--chunk_size', type=int, default=-1, help='Chunk size for prefix chunking during context processing (default: -1 = disabled, >0 = enabled)')
9395
parser.add_argument('--non-interactive', action=argparse.BooleanOptionalAction, required=False, default=False, help='Non-interactive mode, mainly for CI usage')
9496

9597
args = parser.parse_args()

src/config.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,8 @@ struct Search_Element : JSON::Element {
778778
v_.length_penalty = static_cast<float>(JSON::Get<double>(value));
779779
} else if (name == "random_seed") {
780780
v_.random_seed = static_cast<int>(JSON::Get<double>(value));
781+
} else if (name == "chunk_size") {
782+
v_.chunk_size = static_cast<int>(JSON::Get<double>(value));
781783
} else if (name == "do_sample") {
782784
v_.do_sample = JSON::Get<bool>(value);
783785
} else if (name == "past_present_share_buffer") {

src/config.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ struct Config {
266266
float length_penalty{1.0f}; // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
267267
bool past_present_share_buffer{}; // The past/present kv tensors are shared and allocated once to max_length (cuda only)
268268
int random_seed{-1}; // -1 = Seed with random device, otherwise use value to seed RNG
269+
int chunk_size{-1}; // Chunk size for prefix chunking during context processing. -1 = disabled, >0 = enabled with specified chunk size.
269270
} search;
270271

271272
void AddMapping(const std::string& nominal_name, const std::string& graph_name);

src/models/decoder_only.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu
2929

3030
DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
3131
size_t num_tokens = next_tokens.size();
32-
const size_t chunk_size = 1024; // Experimental value
32+
const size_t chunk_size = static_cast<size_t>(model_.config_->search.chunk_size);
3333

34-
if (num_tokens > chunk_size) {
35-
// Chunking logic for context phase - process in chunks of 512 tokens
34+
if (chunk_size > 0 && num_tokens > chunk_size) {
35+
// Chunking logic for context phase - process in chunks based on configured chunk_size
3636
size_t processed_tokens = 0;
3737
int length = total_length - static_cast<int>(num_tokens);
3838
while (processed_tokens < num_tokens) {

0 commit comments

Comments
 (0)