|
26 | 26 | import numpy as np
|
27 | 27 | from PIL import Image
|
28 | 28 | from transformers import PreTrainedTokenizerBase
|
| 29 | +from typing_extensions import deprecated |
29 | 30 |
|
30 | 31 | from vllm.lora.request import LoRARequest
|
31 | 32 | from vllm.lora.utils import get_adapter_absolute_path
|
@@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
486 | 487 | "--dataset-name",
|
487 | 488 | type=str,
|
488 | 489 | default="random",
|
489 |
| - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], |
| 490 | + choices=[ |
| 491 | + "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", |
| 492 | + "prefix_repetition" |
| 493 | + ], |
490 | 494 | help="Name of the dataset to benchmark on.",
|
491 | 495 | )
|
492 | 496 | parser.add_argument(
|
@@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
603 | 607 | "from the sampled HF dataset.",
|
604 | 608 | )
|
605 | 609 |
|
| 610 | + prefix_repetition_group = parser.add_argument_group( |
| 611 | + "prefix repetition dataset options") |
| 612 | + prefix_repetition_group.add_argument( |
| 613 | + "--prefix-repetition-prefix-len", |
| 614 | + type=int, |
| 615 | + default=256, |
| 616 | + help="Number of prefix tokens per request, used only for prefix " |
| 617 | + "repetition dataset.", |
| 618 | + ) |
| 619 | + prefix_repetition_group.add_argument( |
| 620 | + "--prefix-repetition-suffix-len", |
| 621 | + type=int, |
| 622 | + default=256, |
| 623 | + help="Number of suffix tokens per request, used only for prefix " |
| 624 | + "repetition dataset. Total input length is prefix_len + suffix_len.", |
| 625 | + ) |
| 626 | + prefix_repetition_group.add_argument( |
| 627 | + "--prefix-repetition-num-prefixes", |
| 628 | + type=int, |
| 629 | + default=10, |
| 630 | + help="Number of prefixes to generate, used only for prefix repetition " |
| 631 | + "dataset. Prompts per prefix is num_requests // num_prefixes.", |
| 632 | + ) |
| 633 | + prefix_repetition_group.add_argument( |
| 634 | + "--prefix-repetition-output-len", |
| 635 | + type=int, |
| 636 | + default=128, |
| 637 | + help="Number of output tokens per request, used only for prefix " |
| 638 | + "repetition dataset.", |
| 639 | + ) |
| 640 | + |
606 | 641 |
|
607 | 642 | def get_samples(args, tokenizer) -> list[SampleRequest]:
|
608 | 643 | if args.dataset_name == "custom":
|
@@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
721 | 756 | output_len=args.random_output_len,
|
722 | 757 | range_ratio=args.random_range_ratio,
|
723 | 758 | ),
|
| 759 | + "prefix_repetition": |
| 760 | + lambda: PrefixRepetitionRandomDataset( |
| 761 | + random_seed=args.seed, dataset_path=args.dataset_path |
| 762 | + ).sample( |
| 763 | + tokenizer=tokenizer, |
| 764 | + num_requests=args.num_prompts, |
| 765 | + prefix_len=args.prefix_repetition_prefix_len, |
| 766 | + suffix_len=args.prefix_repetition_suffix_len, |
| 767 | + num_prefixes=args.prefix_repetition_num_prefixes, |
| 768 | + output_len=args.prefix_repetition_output_len, |
| 769 | + ), |
724 | 770 | }
|
725 | 771 |
|
726 | 772 | try:
|
@@ -828,7 +874,9 @@ def sample(
|
828 | 874 | # Sonnet Dataset Implementation
|
829 | 875 | # -----------------------------------------------------------------------------
|
830 | 876 |
|
831 |
| - |
| 877 | +@deprecated( |
| 878 | + "SonnetDataset is deprecated and will be removed in a future version.", |
| 879 | +) |
832 | 880 | class SonnetDataset(BenchmarkDataset):
|
833 | 881 | """
|
834 | 882 | Simplified implementation of the Sonnet dataset. Loads poem lines from a
|
@@ -1537,3 +1585,84 @@ def sample(
|
1537 | 1585 |
|
1538 | 1586 | self.maybe_oversample_requests(sampled_requests, num_requests)
|
1539 | 1587 | return sampled_requests
|
| 1588 | + |
| 1589 | + |
| 1590 | +# ----------------------------------------------------------------------------- |
| 1591 | +# Prefix Repetition Dataset Implementation |
| 1592 | +# ----------------------------------------------------------------------------- |
| 1593 | + |
| 1594 | + |
class PrefixRepetitionRandomDataset(BenchmarkDataset):
    """Synthetic dataset of prompts that share a small pool of random prefixes.

    Each prompt is ``prefix + suffix``: ``num_prefixes`` distinct random
    prefixes are generated, and ``num_requests // num_prefixes`` prompts are
    emitted per prefix (any remainder is intentionally dropped, matching the
    CLI help text). Useful for benchmarking prefix-caching behavior.
    """

    # Default values copied from benchmark_serving.py for the repeated prefix
    # dataset.
    DEFAULT_PREFIX_LEN = 256
    DEFAULT_SUFFIX_LEN = 256
    DEFAULT_NUM_PREFIXES = 10
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Seed both global RNGs so repeated runs produce identical prompts
        # (same convention as the other synthetic datasets in this module).
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        suffix_len: int = DEFAULT_SUFFIX_LEN,
        num_prefixes: int = DEFAULT_NUM_PREFIXES,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build and return the shuffled list of sampled requests.

        Raises:
            ValueError: if ``num_requests < num_prefixes`` (no prompts could
                be generated per prefix).
        """
        vocab_size = tokenizer.vocab_size
        prompts_per_prefix = num_requests // num_prefixes
        if prompts_per_prefix == 0:
            raise ValueError(
                f"num_requests ({num_requests}) must be greater than or equal "
                f"to num_prefixes ({num_prefixes})"
            )

        def _tokens_of_exact_length(target_length: int) -> list[int]:
            """Return token ids whose decode/re-encode round-trip has exactly
            ``target_length`` tokens."""
            draft = np.random.randint(
                0, vocab_size, size=target_length).tolist()
            # Random ids rarely survive a decode/encode round-trip unchanged,
            # so re-encode and patch up the length.
            round_trip = tokenizer.encode(
                tokenizer.decode(draft), add_special_tokens=False)
            if len(round_trip) >= target_length:
                # Exact match falls through here too: the slice is a no-op.
                return round_trip[:target_length]
            # Came up short — recursively top up with more consistent tokens.
            shortfall = target_length - len(round_trip)
            return round_trip + _tokens_of_exact_length(shortfall)

        requests = []
        for _ in range(num_prefixes):
            prefix_tokens = _tokens_of_exact_length(prefix_len)
            for _ in range(prompts_per_prefix):
                full_tokens = prefix_tokens + _tokens_of_exact_length(
                    suffix_len)
                requests.append(
                    SampleRequest(
                        prompt=tokenizer.decode(full_tokens),
                        prompt_len=len(full_tokens),
                        expected_output_len=output_len,
                    )
                )

        random.shuffle(requests)
        return requests
0 commit comments