
Commit b45cad7

ray + trtllm works
1 parent d178858 commit b45cad7

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
"""
A simple demonstration of RLHF with TensorRT-LLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design in which training processes and inference processes
are separate and live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration with one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from tensorrt_llm import LLM

class RayLLM:
    """Thin Ray-actor wrapper around the TensorRT-LLM ``LLM`` API."""

    def __init__(self, *args, **kwargs):
        import torch
        dev_count = torch.cuda.device_count()
        print("dev_count: ", dev_count)
        self.mpi_session = None
        for i in range(dev_count):
            device = torch.cuda.get_device_properties(i)
            print(f"pid: {os.getpid()}, device {i}: {device.name}, "
                  f"UUID: {device.uuid}, "
                  f"Memory: {device.total_memory / 1024**2:.0f}MB")
        self.llm = LLM(*args, **kwargs)

    def generate(self, prompts):
        ret = []
        llm_ret = self.llm.generate(prompts)
        for r in llm_ret:
            ret.append(r.outputs[0].text)
        return ret

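
# A possible extension, sketched here only and not used below: the
# TensorRT-LLM LLM API examples pass a SamplingParams object to
# LLM.generate, so the actor could forward one as well, e.g.
#
#     def generate(self, prompts, sampling_params=None):
#         return [r.outputs[0].text
#                 for r in self.llm.generate(prompts, sampling_params)]
#
# which would let the commented-out SamplingParams line further down be used
# (it would also need `from tensorrt_llm import SamplingParams`).
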
"""
47+
Start the training process, here we use huggingface transformers
48+
as an example to hold a model on GPU 0.
49+
"""
50+
51+
52+
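# A sketch only (never called in this script): the checkpoint name and dtype
# below are illustrative assumptions, reusing the TinyLlama model that the
# inference engine loads further down.
def load_train_model():
    return AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.bfloat16,
    ).to("cuda:0")
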
"""
53+
Start the inference process, here we use vLLM to hold a model on GPU 1 and
54+
GPU 2. For the details on how to use ray, please refer to the ray
55+
documentation https://docs.ray.io/en/latest/ .
56+
"""
57+
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
58+
ray.init(include_dashboard=False)

pg_inference = placement_group([{"GPU": 2, "CPU": 0}])
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
"""
Launch the TensorRT-LLM inference engine with tensor parallelism across the
two GPUs in the placement group. The commented-out arguments below are vLLM
options kept for reference.
"""
llm = ray.remote(
    num_cpus=0,
    num_gpus=2,
    scheduling_strategy=scheduling_inference,
)(RayLLM).remote(
    model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
    tensor_parallel_size=2,
    # enforce_eager=True,
    # worker_extension_cls="rlhf_utils.WorkerExtension",
    # distributed_executor_backend="ray",
)

# Generate texts from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts))

for prompt, output in zip(prompts, outputs):
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {output!r}")

# Set up the communication between the training process
# and the inference engine.
# master_address = get_ip()
# master_port = get_open_port()

# handle = llm.collective_rpc.remote("init_weight_update_group",
#                                    args=('master_address', 5, 0, 3))

# # model_update_group = stateless_init_process_group(master_address, master_port,
# #                                                   0, 3, torch.device("cuda:0"))
# # ray.get(handle)

# # simulate training: modify the weights of the model.
# # for name, p in train_model.named_parameters():
# #     p.data.zero_()

# # # sync weights from the training process to the inference engine.
# # for name, p in train_model.named_parameters():
# #     handle = llm.collective_rpc.remote("update_weight",
# #                                        args=(name, p.dtype, p.shape))
# #     model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
# #     ray.get(handle)

# # check if the weights are updated.
# # assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# # use the updated model to generate texts; they will be nonsense
# # because the weights are all zeros.
# outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
# for output in outputs_updated:
#     prompt = output.prompt
#     generated_text = output.outputs[0].text
#     print(f"Prompt: {prompt!r}, "
#           f"Generated text: {generated_text!r}")
