
Commit 1cab2f9

EAGLE 3: Fix preamble so that measured speedup over Eagle 1 becomes 32% instead of 5% on MTBench (vllm-project#25916)
Signed-off-by: Ekagra Ranjan <[email protected]>
1 parent 1e50f1b · commit 1cab2f9

File tree

1 file changed: +39 −33 lines changed

vllm/benchmarks/datasets.py

Lines changed: 39 additions & 33 deletions
@@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Do not oversample if the dataset has " \
         "fewer samples than num-prompts.",
     )
+    parser.add_argument(
+        "--skip-chat-template",
+        action="store_true",
+        help=
+        "Skip applying chat template to prompt for datasets that support it.",
+    )
 
     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
@@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help=
         "Number of output tokens per request, used only for custom dataset.",
     )
-    custom_group.add_argument(
-        "--custom-skip-chat-template",
-        action="store_true",
-        help=
-        "Skip applying chat template to prompt, used only for custom dataset.",
-    )
 
     spec_bench_group = parser.add_argument_group("spec bench dataset options")
     spec_bench_group.add_argument(
@@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             output_len=args.custom_output_len,
-            skip_chat_template=args.custom_skip_chat_template,
+            skip_chat_template=args.skip_chat_template,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         )
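
Usage sketch (illustration only, not part of this diff): after the change above, the skip option is read from one shared argument instead of a custom-dataset-only one. The snippet below reuses the add_argument call introduced in the first hunk on a bare argparse parser; the standalone parser is an assumption made for the sketch, since the real code attaches the argument inside add_dataset_parser.

import argparse

# Same argument as added above, attached to a plain parser purely for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-chat-template",
    action="store_true",
    help="Skip applying chat template to prompt for datasets that support it.",
)

args = parser.parse_args(["--skip-chat-template"])
# get_samples() now forwards this single attribute to every dataset sampler that
# accepts it; previously only the custom dataset read args.custom_skip_chat_template.
assert args.skip_chat_template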
@@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
+            skip_chat_template=args.skip_chat_template,
             **hf_kwargs
         )
 
@@ -1815,7 +1816,6 @@ def load_data(self) -> None:
 
     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
-        kwargs["skip_chat_template"] = False
         return super().sample(**kwargs)
 
 
@@ -2221,6 +2221,7 @@ def sample(self,
                num_requests: int,
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
+               skip_chat_template: bool = False,
                request_id_prefix: str = "",
                no_oversample: bool = False,
                **kwargs) -> list:
@@ -2236,14 +2237,15 @@ def sample(self,
             )
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2284,6 +2286,7 @@ def sample(
         num_requests: int,
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         **kwargs,
@@ -2298,14 +2301,15 @@ def sample(
             prompt = item["turns"][0]
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2349,6 +2353,7 @@ def sample(
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         output_len: Optional[int] = None,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         min_distance: float = 0.0,
@@ -2372,7 +2377,7 @@ def sample(
 
             # template copied from
             # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
-            instruction = f"""Given a code file, please apply the change requests and generate the new file.
+            prompt = f"""Given a code file, please apply the change requests and generate the new file.
 
 Original file:
 ```python
@@ -2385,14 +2390,15 @@ def sample(
 Please generate the new code file in the "New file" section below.""" # noqa: E501
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": instruction
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)

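To illustrate the "preamble" the commit title refers to (a standalone sketch, not part of this change): applying the chat template wraps the raw prompt in role and header text, so the tokenized prompt the benchmark measures differs from the raw one. The model name and prompt below are placeholders; only the apply_chat_template call mirrors the samplers above.

from transformers import AutoTokenizer

# Placeholder model; any chat model that ships a chat template behaves the same way.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

raw = "Compose an engaging travel blog post about a recent trip to Hawaii."  # placeholder prompt

# Default path: the samplers above wrap the prompt in the chat template.
templated = tok.apply_chat_template(
    [{"role": "user", "content": raw}],
    add_generation_prompt=True,
    tokenize=False,
)

# --skip-chat-template path: the raw prompt is tokenized as-is.
print("raw tokens:      ", len(tok(raw).input_ids))
print("templated tokens:", len(tok(templated).input_ids))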