Skip to content

Commit c9ffa2d

Browse files
weixiao-huangyoukaichao
authored andcommitted
[RL] fast weight update with zmq + ipc handles (vllm-project#24295)
Signed-off-by: huangweixiao <[email protected]> Signed-off-by: youkaichao <[email protected]> Co-authored-by: youkaichao <[email protected]>
1 parent 33fa7b8 commit c9ffa2d

File tree

2 files changed

+152
-33
lines changed

2 files changed

+152
-33
lines changed

examples/offline_inference/rlhf_colocate.py

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@
2828
https://docs.ray.io/en/latest/placement-groups.html
2929
"""
3030

31+
import gc
3132
import os
3233

3334
import ray
3435
import torch
36+
import zmq
3537
from ray.util.placement_group import placement_group
3638
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
39+
from torch.multiprocessing.reductions import reduce_tensor
3740

3841
from vllm import LLM
3942

@@ -86,20 +89,72 @@ def __init__(self):
8689
from vllm.platforms import current_platform
8790

8891
self.device_uuid = current_platform.get_device_uuid(0)
92+
self.zmq_context = zmq.Context()
93+
self.zmq_address_counter = 0
94+
self.zmq_handle = None
8995

9096
def report_device_id(self) -> str:
9197
return self.device_uuid
9298

93-
def get_weight_ipc_handles(self):
94-
from torch.multiprocessing.reductions import reduce_tensor
99+
def get_zmq_handles(self) -> dict[str, str]:
100+
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
101+
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
102+
self.zmq_address_counter += 1
103+
return {self.device_uuid: self.zmq_handle}
95104

96-
data = {}
97-
for name, p in self.model.named_parameters():
98-
# A training actor might hold only a subset of the weights and may
99-
# need to gather weights from other actors. For demonstration
100-
# purposes, each training actor owns the full weight set.
101-
data[name] = reduce_tensor(p.detach())
102-
return {self.device_uuid: data}
105+
def update_weights(self):
106+
# align size to avoid misaligned address
107+
align_size = 256
108+
109+
def get_size(p: torch.Tensor) -> int:
110+
return (p.nbytes + align_size - 1) // align_size * align_size
111+
112+
named_parameters: dict[str, torch.nn.Parameter] = dict(
113+
self.model.named_parameters()
114+
)
115+
max_tensor_size = max(get_size(p) for p in named_parameters.values())
116+
# use max_tensor_size * 2 as buffer size
117+
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
118+
s = self.zmq_context.socket(zmq.REQ)
119+
s.bind(self.zmq_handle)
120+
handle = reduce_tensor(buffer)
121+
122+
offset = 0
123+
buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
124+
named_tensors: list[dict] = []
125+
real_tensors: list[torch.Tensor] = []
126+
for name, p in named_parameters.items():
127+
size = get_size(p)
128+
if offset + size > buffer.numel():
129+
buckets.append((named_tensors, real_tensors))
130+
named_tensors, real_tensors = [], []
131+
offset = 0
132+
# assume tensors are contiguous
133+
named_tensors.append(
134+
{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
135+
)
136+
real_tensors.append(p)
137+
offset += size
138+
if named_tensors:
139+
buckets.append((named_tensors, real_tensors))
140+
s.send_pyobj(handle)
141+
s.recv()
142+
for named_tensors, real_tensors in buckets:
143+
offset = 0
144+
for p in real_tensors:
145+
buffer[offset : offset + p.nbytes].data.copy_(
146+
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
147+
)
148+
offset += get_size(p)
149+
torch.cuda.synchronize()
150+
s.send_pyobj(named_tensors)
151+
s.recv()
152+
s.send_pyobj(None)
153+
s.recv()
154+
s.close()
155+
del buffer
156+
gc.collect()
157+
torch.cuda.empty_cache()
103158

104159

105160
# Ray manages four GPUs.
@@ -175,18 +230,22 @@ def get_weight_ipc_handles(self):
175230
# the second inference engine.
176231
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
177232

178-
print("Gather all the IPC handles from the training actors.")
179-
ipc_handles = {}
233+
print("Gather all the ZMQ handles from the training actors.")
234+
zmq_handles = {}
180235
for actor in training_actors:
181-
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
236+
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
237+
238+
print(f"ZMQ handles: {zmq_handles}")
182239

183240
print("Update the weights of the inference engines.")
184-
for llm in inference_engines:
185-
ray.get(
186-
llm.collective_rpc.remote(
187-
"update_weights_from_ipc_handles", args=(ipc_handles,)
188-
)
189-
)
241+
ray.get(
242+
[actor.update_weights.remote() for actor in training_actors]
243+
+ [
244+
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
245+
for llm in inference_engines
246+
]
247+
)
248+
190249
print("Check if the weights are updated.")
191250
for llm in inference_engines:
192251
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))

examples/offline_inference/rlhf_utils.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import gc
4+
from typing import Callable, Optional, TypedDict
5+
36
import torch
7+
import zmq
48

59

610
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
@@ -66,6 +70,27 @@ def check_weights_changed(self):
6670
return weights_updated
6771

6872

73+
def rebuild_ipc(
74+
handle: tuple[Callable, tuple], device_id: Optional[int] = None
75+
) -> torch.Tensor:
76+
func, args = handle
77+
list_args = list(args)
78+
if device_id is not None:
79+
# the key is to change device id to the current device id
80+
# in case two processes have different CUDA_VISIBLE_DEVICES
81+
list_args[6] = device_id
82+
buffer = func(*list_args)
83+
return buffer
84+
85+
86+
class FlattenedTensorMetadata(TypedDict):
87+
name: str
88+
shape: torch.Size
89+
dtype: torch.dtype
90+
# specify the start offset of this tensor in shared ipc_buffer tensor
91+
offset: int
92+
93+
6994
class ColocateWorkerExtension:
7095
"""
7196
The class for vLLM's worker to inherit from, in the colocate setting.
@@ -76,27 +101,62 @@ class ColocateWorkerExtension:
76101
should pass the full qualified name as `worker_extension_cls` argument.
77102
"""
78103

104+
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
105+
from vllm.model_executor.model_loader.utils import process_weights_after_loading
106+
107+
assert self.device is not None
108+
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
109+
self._zmq_ctx = zmq.Context()
110+
socket = self._zmq_ctx.socket(zmq.REP)
111+
socket.connect(zmq_handles[self.report_device_id()])
112+
buffer: Optional[torch.Tensor] = None
113+
while True:
114+
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
115+
socket.recv_pyobj()
116+
)
117+
if payload is None:
118+
# means the update is done
119+
process_weights_after_loading(
120+
self.model_runner.model, self.model_config, self.device
121+
)
122+
torch.cuda.synchronize()
123+
socket.send(b"")
124+
break
125+
if isinstance(payload, tuple):
126+
# an ipc handle that vLLM can use `func, args = handle`
127+
# and `func(*args)` to rebuild GPU tensor.
128+
buffer = rebuild_ipc(payload, self.device.index)
129+
assert buffer.dtype == torch.uint8
130+
socket.send(b"")
131+
continue
132+
assert isinstance(payload, list)
133+
assert buffer is not None
134+
weights = []
135+
for item in payload:
136+
shape = item["shape"]
137+
if isinstance(shape, (list, tuple)):
138+
shape = torch.Size(shape)
139+
assert isinstance(shape, torch.Size)
140+
dtype, offset = item["dtype"], item["offset"]
141+
size = dtype.itemsize * shape.numel()
142+
tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
143+
weights.append((item["name"], tensor))
144+
self.model_runner.model.load_weights(weights=weights)
145+
del weights
146+
torch.cuda.synchronize()
147+
socket.send(b"")
148+
149+
socket.close()
150+
del buffer
151+
gc.collect()
152+
torch.cuda.empty_cache()
153+
79154
def report_device_id(self) -> str:
80155
from vllm.platforms import current_platform
81156

82157
self.device_uuid = current_platform.get_device_uuid(self.device.index)
83158
return self.device_uuid
84159

85-
def update_weights_from_ipc_handles(self, ipc_handles):
86-
handles = ipc_handles[self.device_uuid]
87-
device_id = self.device.index
88-
weights = []
89-
for name, handle in handles.items():
90-
func, args = handle
91-
list_args = list(args)
92-
# the key is to change device id to the current device id
93-
# in case two processes have different CUDA_VISIBLE_DEVICES
94-
list_args[6] = device_id
95-
tensor = func(*list_args)
96-
weights.append((name, tensor))
97-
self.model_runner.model.load_weights(weights=weights)
98-
torch.cuda.synchronize()
99-
100160
def check_weights_changed(self):
101161
"""
102162
Check if the weights are updated to 0.

0 commit comments

Comments
 (0)