Skip to content

Commit 1c6d739

Browse files
ekagra-ranjanlk-chen
authored and committed
[Benchmark] Add single turn MTBench to Serving Bench (vllm-project#17202)
1 parent 54f7bab commit 1c6d739

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

benchmarks/benchmark_dataset.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,60 @@ def sample(self,
771771
return sampled_requests
772772

773773

774+
# -----------------------------------------------------------------------------
775+
# MT-Bench Dataset Implementation
776+
# -----------------------------------------------------------------------------
777+
778+
779+
class MTBenchDataset(HuggingFaceDataset):
    """
    Single-turn variant of the MT-Bench dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    Only the first turn of each conversation is kept, mirroring the
    speculative-decoding benchmark setup in vLLM:
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """ # noqa: E501

    # Average output length used by the spec-decoding benchmark in vLLM.
    DEFAULT_OUTPUT_LEN = 256
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        """Build up to ``num_requests`` single-turn requests.

        Each record's first turn is rendered through the tokenizer's chat
        template (untokenized, with the generation prompt appended) and
        wrapped in a ``SampleRequest``. The request list is oversampled if
        the dataset is smaller than ``num_requests``.
        """
        # Fall back to the class default when no explicit length is given.
        target_output_len = (self.DEFAULT_OUTPUT_LEN
                             if output_len is None else output_len)

        requests: list = []
        for record in self.data:
            if len(requests) >= num_requests:
                break
            # Single-turn benchmark: only the opening user message is used.
            first_turn = record['turns'][0]

            # Render through the model's chat template without tokenizing,
            # so prompt_len can be measured on the final templated string.
            templated = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": first_turn
                }],
                add_generation_prompt=True,
                tokenize=False)

            requests.append(
                SampleRequest(
                    prompt=templated,
                    prompt_len=len(tokenizer(templated).input_ids),
                    expected_output_len=target_output_len,
                ))
        self.maybe_oversample_requests(requests, num_requests)
        return requests
826+
827+
774828
# -----------------------------------------------------------------------------
775829
# AIMO Dataset Implementation
776830
# -----------------------------------------------------------------------------

benchmarks/benchmark_serving.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@
5252

5353
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
5454
ConversationDataset, HuggingFaceDataset,
55-
InstructCoderDataset, RandomDataset,
56-
SampleRequest, ShareGPTDataset, SonnetDataset,
57-
VisionArenaDataset)
55+
InstructCoderDataset, MTBenchDataset,
56+
RandomDataset, SampleRequest, ShareGPTDataset,
57+
SonnetDataset, VisionArenaDataset)
5858
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
5959

6060
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -595,6 +595,9 @@ def main(args: argparse.Namespace):
595595
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
596596
dataset_class = InstructCoderDataset
597597
args.hf_split = "train"
598+
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
599+
dataset_class = MTBenchDataset
600+
args.hf_split = "train"
598601
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
599602
dataset_class = ConversationDataset
600603
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:

0 commit comments

Comments
 (0)