Skip to content

Commit 871c6b4

Browse files
authored
[None] [feat] skip batch_tokenize_prompts in CustomDataset (#10214)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
1 parent 522f1d2 commit 871c6b4

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorrt_llm/serve/scripts/benchmark_dataset.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ def sample(self, tokenizer: PreTrainedTokenizerBase,
728728
# Collect all prompts and metadata
729729
prompts = []
730730
max_tokens_list = []
731+
prompt_lengths = []
731732

732733
for i, entry in enumerate(self.data):
733734
if len(prompts) >= num_requests:
@@ -736,10 +737,18 @@ def sample(self, tokenizer: PreTrainedTokenizerBase,
736737
max_tokens = entry["input"]["max_tokens"]
737738
prompts.append(prompt)
738739
max_tokens_list.append(max_tokens)
740+
if "num_tokens" in entry["input"] and isinstance(
741+
entry["input"]["num_tokens"],
742+
int) and entry["input"]["num_tokens"] > 0:
743+
prompt_lengths.append(entry["input"]["num_tokens"])
739744

740-
# Use batch tokenization utility
741-
prompt_lengths, _ = batch_tokenize_prompts(
742-
prompts, tokenizer, progress_name="custom dataset prompts")
745+
if len(prompt_lengths) > 0 and len(prompt_lengths) == len(prompts):
746+
print(
747+
f"skipping batch tokenization because prompt_lengths are already available"
748+
)
749+
else:
750+
prompt_lengths, _ = batch_tokenize_prompts(
751+
prompts, tokenizer, progress_name="custom dataset prompts")
743752

744753
# Create SampleRequest objects
745754
samples = []

0 commit comments

Comments (0)