
Commit 01700aa

Add RL perf reproduce script
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
1 parent 0a4c591 commit 01700aa
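
In short, the script launches several TensorRT-LLM AsyncLLM engines as Ray actors inside one shared placement group, distributes a file of pre-tokenized prompts across them round-robin, and reports end-to-end generation throughput, reproducing the inference side of an RL post-training flow on a single node.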

1 file changed: 274 additions, 0 deletions
@@ -0,0 +1,274 @@
import argparse
import asyncio
import json
import os
import time
from pathlib import Path

import ray
import torch
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoConfig

from tensorrt_llm import AsyncLLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SamplingParams


@ray.remote
class trtllm_instance:
    """Ray actor that owns one AsyncLLM engine and its sampling parameters."""

    def __init__(self, async_llm_kwargs: dict, sampling_kwargs: dict):
        self.async_llm_kwargs = async_llm_kwargs
        self.sampling_kwargs = sampling_kwargs
        self.llm = None
        self.sampling_params = None

    async def init_llm(self):
        # Build the engine inside the actor process so each instance owns its own GPUs.
        self.llm = AsyncLLM(
            model=self.async_llm_kwargs["model"],
            backend="pytorch",
            orchestrator_type=self.async_llm_kwargs["orchestrator_type"],
            ray_worker_extension_cls=self.async_llm_kwargs["ray_worker_extension_cls"],
            kv_cache_config=KvCacheConfig(**self.async_llm_kwargs["kv_cache_config"]),
            cuda_graph_config=CudaGraphConfig(**self.async_llm_kwargs["cuda_graph_config"]),
            max_seq_len=self.async_llm_kwargs["max_seq_len"],
            max_batch_size=self.async_llm_kwargs["max_batch_size"],
            max_num_tokens=self.async_llm_kwargs["max_num_tokens"],
            tensor_parallel_size=self.async_llm_kwargs["tensor_parallel_size"],
            trust_remote_code=self.async_llm_kwargs["trust_remote_code"],
            enable_sleep=True,
            sampler_type=self.async_llm_kwargs["sampler_type"],
            placement_groups=self.async_llm_kwargs["placement_groups"],
            placement_bundle_indices=self.async_llm_kwargs["placement_bundle_indices"],
            per_worker_gpu_share=self.async_llm_kwargs["per_worker_gpu_share"],
            batch_wait_timeout_iters=32,
            batch_wait_max_tokens_ratio=0.5,
        )
        await self.llm.setup_async()
        self.sampling_params = SamplingParams(
            temperature=self.sampling_kwargs["temperature"],
            top_p=self.sampling_kwargs["top_p"],
            top_k=self.sampling_kwargs["top_k"],
            max_tokens=self.sampling_kwargs["max_tokens"],
            logprobs=self.sampling_kwargs["logprobs"],
            detokenize=self.sampling_kwargs["detokenize"],
            end_id=self.sampling_kwargs["end_id"],
            pad_id=self.sampling_kwargs["pad_id"],
            stop_token_ids=self.sampling_kwargs["stop_token_ids"],
            include_stop_str_in_output=self.sampling_kwargs["include_stop_str_in_output"],
        )

    async def generate(self, prompt: list[int]):
        """Generate a completion for a single pre-tokenized prompt."""
        outputs = await self.llm.generate_async(inputs=prompt, sampling_params=self.sampling_params)
        token_ids = outputs.outputs[0].token_ids
        log_probs = None
        if self.sampling_kwargs["logprobs"] is not None:
            log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs]
        return token_ids, log_probs

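# Note: each trtllm_instance actor is created with num_gpus=0 (see setup_rl_llm
# below); GPU ownership is delegated to the AsyncLLM Ray workers through the
# placement_groups / placement_bundle_indices arguments.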
async def setup_rl_llm(args):
    data_path = Path(args.data_path)
    with open(data_path, "r") as f:
        prompts = json.load(f)

    hf_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)

    num_instances = args.num_instances
    num_gpus = args.tp_size * num_instances
    available_gpus = torch.cuda.device_count()
    if num_gpus > 8:
        raise ValueError(
            f"Number of GPUs required ({num_gpus}) is greater than 8. This script only supports a single node."
        )
    if available_gpus < num_gpus:
        raise ValueError(
            f"Number of available GPUs ({available_gpus}) is less than the number required ({num_gpus})."
        )

    # Let Ray leave CUDA_VISIBLE_DEVICES untouched so the TensorRT-LLM workers
    # manage their own GPU assignment.
    os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    runtime_env = {"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}
    pg = None

    try:
        ray.init()
        # One bundle per GPU; STRICT_PACK keeps all bundles on the same node.
        pg = placement_group(
            [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK"
        )

        ray.get(pg.ready())

        tp_size = args.tp_size
        # Instance i is assigned bundles [i * tp_size, (i + 1) * tp_size) of the shared group.
        placement_group_list = [[pg] for _ in range(num_instances)]
        placement_bundle_indices_list = [
            [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances)
        ]

        llm_instances = []
        for i in range(num_instances):
            llm_instances.append(
                trtllm_instance.options(
                    num_cpus=0,
                    num_gpus=0,
                    runtime_env=runtime_env,
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_capture_child_tasks=True,
                    ),
                ).remote(
                    async_llm_kwargs={
                        "model": args.model_dir,
                        "backend": "pytorch",
                        "orchestrator_type": "ray",
                        "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
                        "kv_cache_config": {
                            "enable_block_reuse": args.enable_block_reuse,
                            "free_gpu_memory_fraction": args.kv_cache_fraction,
                        },
                        "cuda_graph_config": {
                            "enable_padding": args.enable_padding,
                            "batch_sizes": args.batch_sizes,
                            "max_batch_size": 0 if args.batch_sizes else args.max_batch_size,
                        },
                        "max_seq_len": args.max_seq_len,
                        "max_batch_size": args.max_batch_size,
                        "max_num_tokens": args.max_num_tokens,
                        "tensor_parallel_size": args.tp_size,
                        "trust_remote_code": args.trust_remote_code,
                        "enable_sleep": True,
                        "sampler_type": args.sampler_type,
                        "placement_groups": placement_group_list[i],
                        "placement_bundle_indices": placement_bundle_indices_list[i],
                        "per_worker_gpu_share": 0.5,
                    },
                    sampling_kwargs={
                        "temperature": args.temperature,
                        "top_p": args.top_p,
                        "top_k": args.top_k,
                        "max_tokens": args.max_tokens,
                        "logprobs": args.logprobs,
                        "detokenize": False,
                        "end_id": -1,
                        "pad_id": hf_config.pad_token_id,
                        "stop_token_ids": [hf_config.eos_token_id],
                        "include_stop_str_in_output": True,
                    },
                )
            )
        # Wait for every actor to start, then build all engines in parallel.
        ray.get([llm.__ray_ready__.remote() for llm in llm_instances])
        ray.get([llm.init_llm.remote() for llm in llm_instances])

        total_prompts = len(prompts)

        print(
            f"Starting generation for {total_prompts} prompts across {num_instances} instances..."
        )
        start_time = time.time()

        # Helper to await a Ray remote call without blocking the event loop.
        async def generate_single_prompt(instance, prompt):
            """Generate for a single prompt asynchronously."""
            object_ref = instance.generate.remote(prompt=prompt)
            result = await asyncio.to_thread(ray.get, object_ref)
            return result

        # Distribute prompts across instances round-robin and run them all concurrently.
        tasks = [
            generate_single_prompt(llm_instances[idx % num_instances], prompt)
            for idx, prompt in enumerate(prompts)
        ]

        results = await asyncio.gather(*tasks)
        end_time = time.time()

        print(f"Time taken: {end_time - start_time:.2f} seconds")
        print(f"Total prompts: {total_prompts}")
        print(f"Throughput: {total_prompts / (end_time - start_time):.2f} prompts/sec")
    finally:
        if pg is not None:
            remove_placement_group(pg)
        ray.shutdown()


def add_rl_llm_args(parser):
    parser.add_argument("--model_dir", type=str, required=True, help="Model checkpoint directory.")
    parser.add_argument("--data_path", type=str, required=True, help="Input data file path.")
    parser.add_argument(
        "--num_instances", type=int, required=True, help="Number of trtllm instances."
    )

    # AsyncLLM parameters
    parser.add_argument("--tp_size", type=int, required=True, help="Tensor parallel size.")
    parser.add_argument("--max_seq_len", type=int, default=2048, help="Maximum sequence length.")
    parser.add_argument("--max_batch_size", type=int, default=384, help="Maximum batch size.")
    parser.add_argument(
        "--max_num_tokens", type=int, default=32768, help="Maximum number of tokens."
    )
    parser.add_argument(
        "--sampler_type",
        type=str,
        default="TRTLLMSampler",
        choices=["TRTLLMSampler", "TorchSampler"],
        help="Sampler type.",
    )
    # BooleanOptionalAction avoids the argparse pitfall where type=bool treats
    # any non-empty string (including "False") as True.
    parser.add_argument(
        "--trust_remote_code",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to trust remote code.",
    )

    # KV Cache Config parameters
    parser.add_argument(
        "--kv_cache_fraction",
        type=float,
        default=0.6,
        help="The fraction of GPU memory to be used for KV cache.",
    )
    parser.add_argument(
        "--enable_block_reuse",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to enable block reuse for KV cache.",
    )

    # CUDA Graph Config parameters
    parser.add_argument(
        "--enable_padding",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to enable padding for CUDA graphs.",
    )
    parser.add_argument(
        "--batch_sizes",
        type=int,
        nargs="+",
        default=None,
        help="The batch sizes to be used for CUDA graphs. Example: --batch_sizes 16 32 64 128 256",
    )

    # Sampling parameters
    parser.add_argument("--max_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--logprobs", type=int, default=None)

    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(description="RL flow performance reproduction.")
    parser = add_rl_llm_args(parser)
    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()
    asyncio.run(setup_rl_llm(args))


if __name__ == "__main__":
    main()
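
For reference, the script expects --data_path to point to a JSON file containing a list of pre-tokenized prompts, one list of token IDs per prompt (see the json.load call and the generate(prompt: list[int]) signature above). Below is a minimal sketch of how such a file could be produced; the model path and prompt strings are placeholders, not part of the commit:

    import json

    from transformers import AutoTokenizer

    MODEL_DIR = "/path/to/model"  # placeholder; use the checkpoint later passed as --model_dir

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    texts = ["Example prompt one.", "Example prompt two."]  # placeholder prompts
    token_id_prompts = [tokenizer.encode(t) for t in texts]  # one token-ID list per prompt

    with open("prompts.json", "w") as f:
        json.dump(token_id_prompts, f)

A plausible single-node launch with two engine instances at TP=2 (4 GPUs total) would then be: python <script>.py --model_dir /path/to/model --data_path prompts.json --num_instances 2 --tp_size 2. The script's filename is not shown on this page, so it is left as a placeholder.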
