
Commit 7227a26

Enable prefix chunking through the config overlay
1 parent a713973

File tree (4 files changed, +9 / -4 lines):

examples/python/model-generate.py
src/config.cpp
src/config.h
src/models/decoder_only.cpp

examples/python/model-generate.py

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,8 @@ def main(args):
     batch_size = len(prompts)

     config = og.Config(args.model_path)
-    config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}}}}}')
+    # Example: Configure search parameters including chunk_size for prefix chunking
+    config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}, "chunk_size": {args.chunk_size}}}}}')

     if args.execution_provider != "follow_config":
         config.clear_providers()

@@ -90,6 +91,7 @@ def main(args):
     parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
     parser.add_argument('-b', '--batch_size_for_cuda_graph', type=int, default=1, help='Max batch size for CUDA graph')
     parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.')
+    parser.add_argument('--chunk_size', type=int, default=-1, help='Chunk size for prefix chunking during context processing (default: -1 = disabled, >0 = enabled)')
     parser.add_argument('--non-interactive', action=argparse.BooleanOptionalAction, required=False, default=False, help='Non-interactive mode, mainly for CI usage')

     args = parser.parse_args()
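
For context, here is a minimal sketch of what the updated example ends up doing with the new flag, assuming the script were run with --chunk_size 512; the model path and chunk value below are illustrative placeholders, not part of the commit:

import onnxruntime_genai as og

# Illustrative values only (assumed for this sketch, not from the commit)
model_path = "path/to/model"
batch_size = 1
chunk_size = 512  # -1 (the default) leaves prefix chunking disabled; any value > 0 enables it

config = og.Config(model_path)
# With these values the f-string overlay resolves to the JSON string:
#   {"search": {"batch_size": 1, "num_beams": 3, "chunk_size": 512}}
config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}, "chunk_size": {chunk_size}}}}}')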

src/config.cpp

Lines changed: 2 additions & 0 deletions
@@ -868,6 +868,8 @@ struct Search_Element : JSON::Element {
       v_.length_penalty = static_cast<float>(JSON::Get<double>(value));
     } else if (name == "random_seed") {
       v_.random_seed = SafeDoubleToInt(JSON::Get<double>(value), name);
+    } else if (name == "chunk_size") {
+      v_.chunk_size = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "do_sample") {
       v_.do_sample = JSON::Get<bool>(value);
     } else if (name == "past_present_share_buffer") {

src/config.h

Lines changed: 1 addition & 0 deletions
@@ -273,6 +273,7 @@ struct Config {
    float length_penalty{1.0f};  // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
    bool past_present_share_buffer{};  // The past/present kv tensors are shared and allocated once to max_length (cuda only)
    int random_seed{-1};  // -1 = Seed with random device, otherwise use value to seed RNG
+   int chunk_size{-1};  // Chunk size for prefix chunking during context processing. -1 = disabled, >0 = enabled with specified chunk size.
  } search;

  void AddMapping(const std::string& nominal_name, const std::string& graph_name);

src/models/decoder_only.cpp

Lines changed: 3 additions & 3 deletions
@@ -29,10 +29,10 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu

 DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
   size_t num_tokens = next_tokens.size();
-  const size_t chunk_size = 1024;  // Experimental value
+  const size_t chunk_size = static_cast<size_t>(model_.config_->search.chunk_size);

-  if (num_tokens > chunk_size) {
-    // Chunking logic for context phase - process in chunks of 512 tokens
+  if (chunk_size > 0 && num_tokens > chunk_size) {
+    // Chunking logic for context phase - process in chunks based on configured chunk_size
     size_t processed_tokens = 0;
     int length = total_length - static_cast<int>(num_tokens);
     while (processed_tokens < num_tokens) {
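
The chunking loop itself lives in the C++ runtime; as a rough conceptual sketch only (a Python illustration of the bookkeeping, assuming each pass consumes at most chunk_size of the remaining prompt tokens, with a hypothetical helper name):

def split_prefill(num_tokens: int, chunk_size: int) -> list[int]:
    """Illustrative only: return the sizes of the pieces a num_tokens-long
    prompt would be processed in during the chunked context phase."""
    if chunk_size <= 0 or num_tokens <= chunk_size:
        return [num_tokens]  # chunking disabled, or the prompt already fits in one pass
    chunks = []
    processed = 0
    while processed < num_tokens:
        current = min(chunk_size, num_tokens - processed)
        chunks.append(current)
        processed += current
    return chunks

# e.g. a 2500-token prompt with chunk_size 1024 -> [1024, 1024, 452]
print(split_prefill(2500, 1024))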
