|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import argparse |
| 17 | +import asyncio |
| 18 | + |
| 19 | +import yaml |
| 20 | +from specdec_bench import datasets, metrics, models, runners |
| 21 | +from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base |
| 22 | + |
# Maps the --engine CLI choice to the backend model class that drives it.
# Keys are the values accepted by the --engine argument below.
engines_available = {
    "TRTLLM": models.TRTLLMPYTModel,
    "VLLM": models.VLLMModel,
    "SGLANG": models.SGLANGModel,
}
| 28 | + |
| 29 | + |
async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10):
    """
    Run every dataset request through the runner with bounded concurrency.

    An ``asyncio.Semaphore`` caps the number of in-flight requests at
    ``concurrency``; within a request, conversation turns are processed
    sequentially so the chat history stays ordered.

    Args:
        runner: The model runner instance; must expose an async ``run`` and a
            ``process_metrics_final`` method.
        dataset: The dataset containing requests (iterated via ``dataset.data``).
        tokenizer: The tokenizer instance (supplies ``eos_token_id`` and is used
            for chat encoding/decoding).
        output_length: Maximum output length per generation.
        postprocess: Callable applied to each decoded assistant response before
            it is appended to the chat history.
        concurrency: Maximum number of concurrent requests (default: 10).

    Returns:
        A list (one entry per request) of chat message lists in
        ``{"role": ..., "content": ...}`` form.

    Raises:
        Exception: Re-raises the first per-request failure after logging every
            failed request.
    """
    semaphore = asyncio.Semaphore(concurrency)
    max_length = output_length
    end_id = tokenizer.eos_token_id

    async def process_single_request(request, i):
        """Process a single request with all its conversation turns."""
        async with semaphore:
            messages = []
            if request.system_prompt is not None:
                messages.append({"role": "system", "content": request.system_prompt})

            for question in request.turns:
                messages.append({"role": "user", "content": question})
                entry_encoded = encode_chat(tokenizer, messages)

                # Run the async runner.run directly
                output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
                output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
                output_text = postprocess(output_text)
                messages.append({"role": "assistant", "content": output_text})

            return messages

    tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)]
    text_outputs = await asyncio.gather(*tasks, return_exceptions=True)

    # Log every failed request (not just the first), then abort with the first
    # failure so the caller gets a hard error instead of partial metrics.
    first_error = None
    for i, result in enumerate(text_outputs):
        if isinstance(result, Exception):
            print(f"Error processing request {i}: {result}")
            if first_error is None:
                first_error = result
    if first_error is not None:
        raise first_error

    runner.process_metrics_final(text_outputs)
    return text_outputs
| 75 | + |
| 76 | + |
def run_simple(args):
    """
    Build the dataset, model, metrics, and runner from parsed CLI args, then
    execute the async benchmark loop to completion.

    Args:
        args: ``argparse.Namespace`` with the fields declared in ``__main__``
            (tokenizer, mtbench / random_isl, engine, model_dir, concurrency,
            runtime_params dict, ...).

    Raises:
        ValueError: If neither ``args.mtbench`` nor ``args.random_isl`` is set.
    """
    tokenizer = get_tokenizer(args.tokenizer)
    dataset_kwargs = args.runtime_params.get("dataset_kwargs", {})
    if args.mtbench is not None:
        dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
    elif args.random_isl is not None:
        dataset = datasets.RandomToken(
            tokenizer, args.random_isl, args.num_requests, **dataset_kwargs
        )
    else:
        # Guard here as well as in __main__: the CLI-level check does not
        # protect direct callers of this function, and without it `dataset`
        # would be unbound, surfacing as a confusing NameError below.
        raise ValueError("Either mtbench or random_isl must be provided")
    engine_args = args.runtime_params.get("engine_args", {})
    # Greedy decoding by default; override via runtime_params yaml.
    sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0})
    model_class = engines_available[args.engine]
    model = model_class(
        args.model_dir,
        max_concurrent_requests=args.concurrency,
        sampling_kwargs=sampling_kwargs,
        speculative_algorithm=args.speculative_algorithm,
        draft_model_dir=args.draft_model_dir,
        speculative_num_steps=args.draft_length,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.ep_size,
        **engine_args,
    )

    metrics_list = [metrics.Timing(args.tp_size)]
    if args.aa_timing:
        metrics_list.append(metrics.AATiming(tokenizer))
    # MTBench gets its quality metric; random input only measures acceptance.
    if args.mtbench is not None:
        metrics_list.insert(0, metrics.MTBench())
    else:
        metrics_list.insert(0, metrics.AcceptanceRate())
    runner = runners.SimpleRunner(model, metrics=metrics_list)

    postprocess = postprocess_base

    asyncio.run(
        run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency)
    )

    runner.clear_metrics()
| 117 | + |
| 118 | + |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer", type=str, required=True, help="Path to the tokenizer directory"
    )
    parser.add_argument(
        "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset"
    )
    parser.add_argument(
        "--random_isl",
        type=int,
        required=False,
        default=None,
        help="How many tokens random input should be.",
    )
    parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run")
    parser.add_argument(
        "--engine",
        type=str,
        required=False,
        default="TRTLLM",
        choices=list(engines_available.keys()),
        help="Engine to use",
    )
    parser.add_argument(
        "--speculative_algorithm",
        type=str,
        required=False,
        default="EAGLE3",
        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
        help="Speculative algorithm to use",
    )
    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
    parser.add_argument(
        "--draft_model_dir",
        type=str,
        required=False,
        default=None,
        help="Path to the draft model directory",
    )
    parser.add_argument(
        "--runtime_params",
        type=str,
        required=False,
        default=None,
        help="Path to the runtime params yaml file",
    )
    parser.add_argument(
        "--output_length", type=int, required=False, default=4096, help="Output length"
    )
    parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length")
    parser.add_argument(
        "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
    )
    parser.add_argument(
        "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        required=False,
        default=1,
        help="Maximum number of concurrent requests",
    )
    parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric")
    args = parser.parse_args()

    # Replace the yaml path with its parsed contents so downstream code can
    # treat runtime_params uniformly as a dict.
    if args.runtime_params is not None:
        with open(args.runtime_params) as f:
            args.runtime_params = yaml.safe_load(f)
    else:
        args.runtime_params = {}

    # Validate via parser.error rather than assert: assert is stripped under
    # `python -O`, and parser.error prints usage and exits with status 2.
    if args.mtbench is None and args.random_isl is None:
        parser.error("Either mtbench or random_isl must be provided")

    run_simple(args)
0 commit comments