|
| 1 | +import argparse |
| 2 | +import asyncio |
| 3 | +import json |
| 4 | +import os |
| 5 | +import time |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +import ray |
| 9 | +import torch |
| 10 | +from ray.util.placement_group import ( |
| 11 | + PlacementGroupSchedulingStrategy, |
| 12 | + placement_group, |
| 13 | + remove_placement_group, |
| 14 | +) |
| 15 | +from transformers import AutoConfig |
| 16 | + |
| 17 | +from tensorrt_llm import AsyncLLM |
| 18 | +from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SamplingParams |
| 19 | + |
| 20 | + |
@ray.remote
class trtllm_instance:
    """Ray actor hosting a single TensorRT-LLM ``AsyncLLM`` engine.

    The constructor only records configuration; the heavyweight engine build
    happens in :meth:`init_llm` so actor creation stays cheap and failures
    surface from an explicit call instead of the constructor.
    """

    def __init__(self, async_llm_kwargs: dict, sampling_kwargs: dict):
        self.async_llm_kwargs = async_llm_kwargs
        self.sampling_kwargs = sampling_kwargs
        # Both are populated lazily by init_llm().
        self.llm = None
        self.sampling_params = None

    async def init_llm(self):
        """Build the AsyncLLM engine and the shared SamplingParams object."""
        cfg = self.async_llm_kwargs
        self.llm = AsyncLLM(
            model=cfg["model"],
            backend="pytorch",
            orchestrator_type=cfg["orchestrator_type"],
            ray_worker_extension_cls=cfg["ray_worker_extension_cls"],
            kv_cache_config=KvCacheConfig(**cfg["kv_cache_config"]),
            cuda_graph_config=CudaGraphConfig(**cfg["cuda_graph_config"]),
            max_seq_len=cfg["max_seq_len"],
            max_batch_size=cfg["max_batch_size"],
            max_num_tokens=cfg["max_num_tokens"],
            tensor_parallel_size=cfg["tensor_parallel_size"],
            trust_remote_code=cfg["trust_remote_code"],
            enable_sleep=True,
            sampler_type=cfg["sampler_type"],
            placement_groups=cfg["placement_groups"],
            placement_bundle_indices=cfg["placement_bundle_indices"],
            per_worker_gpu_share=cfg["per_worker_gpu_share"],
            batch_wait_timeout_iters=32,
            batch_wait_max_tokens_ratio=0.5,
        )
        await self.llm.setup_async()

        sampling = self.sampling_kwargs
        self.sampling_params = SamplingParams(
            temperature=sampling["temperature"],
            top_p=sampling["top_p"],
            top_k=sampling["top_k"],
            max_tokens=sampling["max_tokens"],
            logprobs=sampling["logprobs"],
            detokenize=sampling["detokenize"],
            end_id=sampling["end_id"],
            pad_id=sampling["pad_id"],
            stop_token_ids=sampling["stop_token_ids"],
            include_stop_str_in_output=sampling["include_stop_str_in_output"],
        )

    async def generate(self, prompt: list[int]):
        """Generate a completion for one tokenized prompt.

        Returns:
            (token_ids, log_probs) where ``log_probs`` is None unless
            log-probabilities were requested in the sampling config.
        """
        result = await self.llm.generate_async(inputs=prompt, sampling_params=self.sampling_params)
        completion = result.outputs[0]
        log_probs = None
        if self.sampling_kwargs["logprobs"] is not None:
            # Each per-step dict maps token id -> logprob entry; take the
            # first entry's logprob, matching the engine's ordering.
            log_probs = [next(iter(step.values())).logprob for step in completion.logprobs]
        return completion.token_ids, log_probs
| 72 | + |
| 73 | + |
async def setup_rl_llm(args):
    """Benchmark RL-style generation with one or more TRT-LLM Ray actors.

    Loads tokenized prompts from ``args.data_path``, creates
    ``args.num_instances`` actor-hosted AsyncLLM engines (each spanning
    ``args.tp_size`` GPU bundles of a single STRICT_PACK placement group on
    one node), distributes the prompts round-robin across instances, and
    prints total time and prompt throughput.

    Raises:
        ValueError: if more than 8 GPUs are requested (multi-node is not
            supported here) or fewer GPUs are available than required.
    """
    data_path = Path(args.data_path)
    with open(data_path, "r") as f:
        prompts = json.load(f)

    # Needed only for pad/eos token ids passed into the sampling config.
    hf_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)

    num_instances = args.num_instances
    num_gpus = args.tp_size * num_instances
    available_gpus = torch.cuda.device_count()
    if num_gpus > 8:
        raise ValueError(
            f"Number of GPUs ({num_gpus}) is greater than 8. This script only supports single node."
        )
    if available_gpus < num_gpus:
        raise ValueError(
            f"Number of GPUs ({available_gpus}) is less than number of GPUs required ({num_gpus})."
        )

    # Keep CUDA_VISIBLE_DEVICES untouched so the TRT-LLM workers select their
    # own devices; set it both for this process and for the actors.
    os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    runtime_env = {"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}
    pg = None

    try:
        ray.init()
        # One bundle per GPU; STRICT_PACK forces all bundles onto one node.
        pg = placement_group(
            [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK"
        )

        ray.get(pg.ready())

        tp_size = args.tp_size
        placement_group_list = [[pg] for _ in range(num_instances)]
        # Instance i owns bundles [i*tp_size, (i+1)*tp_size) of the group.
        placement_bundle_indices_list = [
            [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances)
        ]

        llm_instances = []
        for i in range(num_instances):
            llm_instances.append(
                trtllm_instance.options(
                    # The actor itself needs no resources; the engine workers
                    # it spawns claim the placement-group GPUs.
                    num_cpus=0,
                    num_gpus=0,
                    runtime_env=runtime_env,
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_capture_child_tasks=True,
                    ),
                ).remote(
                    async_llm_kwargs={
                        "model": args.model_dir,
                        "backend": "pytorch",
                        "orchestrator_type": "ray",
                        "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
                        "kv_cache_config": {
                            "enable_block_reuse": args.enable_block_reuse,
                            "free_gpu_memory_fraction": args.kv_cache_fraction,
                        },
                        "cuda_graph_config": {
                            "enable_padding": args.enable_padding,
                            "batch_sizes": args.batch_sizes,
                            # Explicit batch sizes take precedence over a max.
                            "max_batch_size": 0 if args.batch_sizes else args.max_batch_size,
                        },
                        "max_seq_len": args.max_seq_len,
                        "max_batch_size": args.max_batch_size,
                        "max_num_tokens": args.max_num_tokens,
                        "tensor_parallel_size": args.tp_size,
                        "trust_remote_code": args.trust_remote_code,
                        "enable_sleep": True,
                        "sampler_type": args.sampler_type,
                        "placement_groups": placement_group_list[i],
                        "placement_bundle_indices": placement_bundle_indices_list[i],
                        "per_worker_gpu_share": 0.5,
                    },
                    sampling_kwargs={
                        "temperature": args.temperature,
                        "top_p": args.top_p,
                        "top_k": args.top_k,
                        "max_tokens": args.max_tokens,
                        "logprobs": args.logprobs,
                        "detokenize": False,
                        "end_id": -1,
                        "pad_id": hf_config.pad_token_id,
                        "stop_token_ids": [hf_config.eos_token_id],
                        "include_stop_str_in_output": True,
                    },
                )
            )
        # Wait for actors to start, then build the engines.
        ray.get([llm.__ray_ready__.remote() for llm in llm_instances])
        ray.get([llm.init_llm.remote() for llm in llm_instances])

        total_prompts = len(prompts)

        print(
            f"Starting generation for {total_prompts} prompts across {num_instances} instances..."
        )
        start_time = time.time()

        # Ray ObjectRefs are directly awaitable on an asyncio event loop, so
        # each remote call is awaited in place. This replaces the previous
        # asyncio.to_thread(ray.get, ref) pattern, which consumed one thread
        # of the bounded default executor per in-flight prompt and throttled
        # concurrency for large prompt sets.
        async def generate_single_prompt(instance, prompt):
            """Generate a single prompt asynchronously."""
            return await instance.generate.remote(prompt=prompt)

        # Create tasks with round-robin distribution across instances.
        tasks = [
            generate_single_prompt(llm_instances[idx % num_instances], prompt)
            for idx, prompt in enumerate(prompts)
        ]

        results = await asyncio.gather(*tasks)
        end_time = time.time()

        print(f"Time taken: {end_time - start_time:.2f} seconds")
        print(f"Total prompts: {total_prompts}")
        print(f"Throughput: {total_prompts / (end_time - start_time):.2f} prompts/sec")
    finally:
        # Best-effort cleanup even when setup or generation fails.
        if pg is not None:
            remove_placement_group(pg)
        ray.shutdown()
| 195 | + |
| 196 | + |
def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` is a well-known trap: ``bool(s)`` is True for
    every non-empty string, so ``--flag False`` silently yields True. This
    converter accepts the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def add_rl_llm_args(parser):
    """Register the RL-flow benchmark arguments on *parser* and return it."""
    parser.add_argument("--model_dir", type=str, required=True, help="Model checkpoint directory.")
    parser.add_argument("--data_path", type=str, required=True, help="Input data file path.")
    parser.add_argument(
        "--num_instances", type=int, required=True, help="Number of trtllm instances."
    )

    # AsyncLLM parameters
    parser.add_argument("--tp_size", type=int, required=True, help="Tensor parallel size.")
    parser.add_argument("--max_seq_len", type=int, default=2048, help="Maximum sequence length.")
    parser.add_argument("--max_batch_size", type=int, default=384, help="Maximum batch size.")
    parser.add_argument(
        "--max_num_tokens", type=int, default=32768, help="Maximum number of tokens."
    )
    parser.add_argument(
        "--sampler_type",
        type=str,
        default="TRTLLMSampler",
        choices=["TRTLLMSampler", "TorchSampler"],
        help="Sampler type.",
    )
    # _str2bool (not type=bool) so "--trust_remote_code False" really disables it.
    parser.add_argument(
        "--trust_remote_code", type=_str2bool, default=True, help="Whether to trust remote code."
    )

    # KV Cache Config parameters
    parser.add_argument(
        "--kv_cache_fraction",
        type=float,
        default=0.6,
        help="The fraction of GPU memory to be used for KV cache.",
    )
    parser.add_argument(
        "--enable_block_reuse",
        type=_str2bool,
        default=True,
        help="Whether to enable block reuse for KV cache.",
    )

    # Cuda Graph Config parameters
    parser.add_argument(
        "--enable_padding",
        type=_str2bool,
        default=True,
        help="Whether to enable padding for CUDA graphs.",
    )
    parser.add_argument(
        "--batch_sizes",
        type=int,
        nargs="+",
        default=None,
        help="The batch sizes to be used for CUDA graphs. Example: --batch_sizes 16 32 64 128 256",
    )

    # Sampling parameters
    parser.add_argument("--max_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--logprobs", type=int, default=None)

    return parser
| 259 | + |
| 260 | + |
def parse_arguments():
    """Build the CLI parser for the RL-flow benchmark and parse sys.argv."""
    parser = add_rl_llm_args(
        argparse.ArgumentParser(description="RL flow performance reproduction.")
    )
    return parser.parse_args()
| 266 | + |
| 267 | + |
def main():
    """Script entry point: parse CLI arguments and run the async benchmark."""
    asyncio.run(setup_rl_llm(parse_arguments()))


if __name__ == "__main__":
    main()
0 commit comments