Skip to content

Commit 625815e

Browse files
authored
NVTesnorRtRtx: Support num_beam > 1 (microsoft#1688)
- Pass the num_beams though the overlay - max batch shapes for NVTensorRtRtx = batch_size * num_beams - Ading @baijumeswani @kunal-vaishnavi @gaugarg-nv for review
1 parent a3cc24d commit 625815e

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

examples/python/model-generate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def main(args):
1919
batch_size = len(prompts)
2020

2121
config = og.Config(args.model_path)
22-
config.overlay(f'{{"search": {{"batch_size": {batch_size}}}}}')
22+
config.overlay(f'{{"search": {{"batch_size": {batch_size}, "num_beams": {3}}}}}')
2323

2424
if args.execution_provider != "follow_config":
2525
config.clear_providers()
@@ -45,7 +45,6 @@ def main(args):
4545
params = og.GeneratorParams(model)
4646

4747
search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
48-
search_options['num_beams'] = 3
4948

5049
if (args.verbose): print(f'Args: {args}')
5150
if (args.verbose): print(f'Search options: {search_options}')

src/models/model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ void ConfigureNvTensorRtRTxProfile(const Config& config, OrtSessionOptions& sess
352352
const int num_layers = config.model.decoder.num_hidden_layers;
353353
const int num_kv_heads = config.model.decoder.num_key_value_heads;
354354
const int head_dim = config.model.decoder.head_size;
355-
const int batch_size = config.search.batch_size;
355+
const int batch_size = config.search.batch_size * config.search.num_beams;
356356

357357
// Get max context length from config
358358
const int max_context_len = config.model.context_length;

0 commit comments

Comments
 (0)