Skip to content

Commit af8ffba

Browse files
eicherseiji and djmmoss
authored and committed
Add PrefixRepetitionRandomDataset to vllm bench serve datasets (vllm-project#20638)
Signed-off-by: Seiji Eicher <[email protected]> Signed-off-by: Duncan Moss <[email protected]>
1 parent eec4da9 commit af8ffba

File tree

1 file changed

+131
-2
lines changed

1 file changed

+131
-2
lines changed

vllm/benchmarks/datasets.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import numpy as np
2727
from PIL import Image
2828
from transformers import PreTrainedTokenizerBase
29+
from typing_extensions import deprecated
2930

3031
from vllm.lora.request import LoRARequest
3132
from vllm.lora.utils import get_adapter_absolute_path
@@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
486487
"--dataset-name",
487488
type=str,
488489
default="random",
489-
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
490+
choices=[
491+
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
492+
"prefix_repetition"
493+
],
490494
help="Name of the dataset to benchmark on.",
491495
)
492496
parser.add_argument(
@@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
603607
"from the sampled HF dataset.",
604608
)
605609

610+
prefix_repetition_group = parser.add_argument_group(
611+
"prefix repetition dataset options")
612+
prefix_repetition_group.add_argument(
613+
"--prefix-repetition-prefix-len",
614+
type=int,
615+
default=256,
616+
help="Number of prefix tokens per request, used only for prefix "
617+
"repetition dataset.",
618+
)
619+
prefix_repetition_group.add_argument(
620+
"--prefix-repetition-suffix-len",
621+
type=int,
622+
default=256,
623+
help="Number of suffix tokens per request, used only for prefix "
624+
"repetition dataset. Total input length is prefix_len + suffix_len.",
625+
)
626+
prefix_repetition_group.add_argument(
627+
"--prefix-repetition-num-prefixes",
628+
type=int,
629+
default=10,
630+
help="Number of prefixes to generate, used only for prefix repetition "
631+
"dataset. Prompts per prefix is num_requests // num_prefixes.",
632+
)
633+
prefix_repetition_group.add_argument(
634+
"--prefix-repetition-output-len",
635+
type=int,
636+
default=128,
637+
help="Number of output tokens per request, used only for prefix "
638+
"repetition dataset.",
639+
)
640+
606641

607642
def get_samples(args, tokenizer) -> list[SampleRequest]:
608643
if args.dataset_name == "custom":
@@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
721756
output_len=args.random_output_len,
722757
range_ratio=args.random_range_ratio,
723758
),
759+
"prefix_repetition":
760+
lambda: PrefixRepetitionRandomDataset(
761+
random_seed=args.seed, dataset_path=args.dataset_path
762+
).sample(
763+
tokenizer=tokenizer,
764+
num_requests=args.num_prompts,
765+
prefix_len=args.prefix_repetition_prefix_len,
766+
suffix_len=args.prefix_repetition_suffix_len,
767+
num_prefixes=args.prefix_repetition_num_prefixes,
768+
output_len=args.prefix_repetition_output_len,
769+
),
724770
}
725771

726772
try:
@@ -828,7 +874,9 @@ def sample(
828874
# Sonnet Dataset Implementation
829875
# -----------------------------------------------------------------------------
830876

831-
877+
@deprecated(
878+
"SonnetDataset is deprecated and will be removed in a future version.",
879+
)
832880
class SonnetDataset(BenchmarkDataset):
833881
"""
834882
Simplified implementation of the Sonnet dataset. Loads poem lines from a
@@ -1537,3 +1585,84 @@ def sample(
15371585

15381586
self.maybe_oversample_requests(sampled_requests, num_requests)
15391587
return sampled_requests
1588+
1589+
1590+
# -----------------------------------------------------------------------------
1591+
# Prefix Repetition Dataset Implementation
1592+
# -----------------------------------------------------------------------------
1593+
1594+
1595+
class PrefixRepetitionRandomDataset(BenchmarkDataset):
    """Synthetic dataset of random-token prompts that share repeated prefixes.

    ``sample`` builds ``num_prefixes`` distinct random prefixes and pairs each
    with ``num_requests // num_prefixes`` freshly generated random suffixes,
    so many requests share an identical leading token span (presumably to
    exercise server-side prefix caching — confirm with callers).
    """

    # Default values copied from benchmark_serving.py for the repeated prefix
    # dataset.
    DEFAULT_PREFIX_LEN = 256
    DEFAULT_SUFFIX_LEN = 256
    DEFAULT_NUM_PREFIXES = 10
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward ``kwargs`` to BenchmarkDataset and seed the RNGs.

        NOTE(review): this seeds the *global* ``random`` and ``numpy`` RNG
        state, so constructing this dataset affects any other code using
        those module-level generators.
        """
        super().__init__(**kwargs)
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        suffix_len: int = DEFAULT_SUFFIX_LEN,
        num_prefixes: int = DEFAULT_NUM_PREFIXES,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """Generate a shuffled list of prefix-sharing sample requests.

        Each prompt decodes ``prefix_len`` shared tokens followed by
        ``suffix_len`` unique tokens; ``expected_output_len`` is fixed at
        ``output_len``. Extra ``kwargs`` are accepted for interface
        compatibility and ignored.

        Raises:
            ValueError: if ``num_requests < num_prefixes`` (zero prompts
                per prefix).
        """
        vocab_size = tokenizer.vocab_size
        # Integer division: any remainder of num_requests is dropped, so
        # exactly prompts_per_prefix * num_prefixes requests are returned.
        prompts_per_prefix = num_requests // num_prefixes
        if prompts_per_prefix == 0:
            raise ValueError(
                f"num_requests ({num_requests}) must be greater than or equal "
                f"to num_prefixes ({num_prefixes})"
            )

        def _generate_exact_length_tokens(target_length: int) -> list[int]:
            """Generate tokens that decode and re-encode to exactly
            target_length."""
            # Generate random tokens, then round-trip through the tokenizer:
            # decode+re-encode can change the token count (merges, specials),
            # so correct by recursion (too short) or truncation (too long).
            tokens = np.random.randint(
                0, vocab_size, size=target_length).tolist()
            text = tokenizer.decode(tokens)
            re_encoded = tokenizer.encode(text, add_special_tokens=False)

            if len(re_encoded) == target_length:
                return re_encoded
            elif len(re_encoded) < target_length:
                # Recursively generate additional consistent tokens
                needed = target_length - len(re_encoded)
                extra_tokens = _generate_exact_length_tokens(needed)
                return re_encoded + extra_tokens
            else:
                # Truncate to target length
                return re_encoded[:target_length]

        requests = []
        for _ in range(num_prefixes):
            # One shared prefix per outer iteration ...
            prefix_tokens = _generate_exact_length_tokens(prefix_len)

            for _ in range(prompts_per_prefix):
                # ... combined with a fresh random suffix per request.
                suffix_tokens = _generate_exact_length_tokens(suffix_len)

                combined_tokens = prefix_tokens + suffix_tokens
                prompt = tokenizer.decode(combined_tokens)
                prompt_len = len(combined_tokens)
                requests.append(
                    SampleRequest(
                        prompt=prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                    )
                )

        # Shuffle so requests sharing a prefix are not issued back-to-back.
        random.shuffle(requests)
        return requests

0 commit comments

Comments
 (0)