|
28 | 28 | https://docs.ray.io/en/latest/placement-groups.html
29 | 29 | """ |
30 | 30 |
|
| 31 | +import gc |
31 | 32 | import os |
32 | 33 |
|
33 | 34 | import ray |
34 | 35 | import torch |
| 36 | +import zmq |
35 | 37 | from ray.util.placement_group import placement_group |
36 | 38 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy |
| 39 | +from torch.multiprocessing.reductions import reduce_tensor |
37 | 40 |
|
38 | 41 | from vllm import LLM |
39 | 42 |
|
@@ -86,20 +89,72 @@ def __init__(self): |
86 | 89 | from vllm.platforms import current_platform |
87 | 90 |
|
88 | 91 | self.device_uuid = current_platform.get_device_uuid(0) |
| 92 | + self.zmq_context = zmq.Context() |
| 93 | + self.zmq_address_counter = 0 |
| 94 | + self.zmq_handle = None |
89 | 95 |
|
90 | 96 | def report_device_id(self) -> str: |
91 | 97 | return self.device_uuid |
92 | 98 |
|
93 | | - def get_weight_ipc_handles(self): |
94 | | - from torch.multiprocessing.reductions import reduce_tensor |
| 99 | + def get_zmq_handles(self) -> dict[str, str]: |
| 100 | + suffix = f"{self.device_uuid}-{self.zmq_address_counter}" |
| 101 | + self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock" |
| 102 | + self.zmq_address_counter += 1 |
| 103 | + return {self.device_uuid: self.zmq_handle} |
95 | 104 |
|
96 | | - data = {} |
97 | | - for name, p in self.model.named_parameters(): |
98 | | - # A training actor might hold only a subset of the weights and may |
99 | | - # need to gather weights from other actors. For demonstration |
100 | | - # purposes, each training actor owns the full weight set. |
101 | | - data[name] = reduce_tensor(p.detach()) |
102 | | - return {self.device_uuid: data} |
| 105 | + def update_weights(self): |
| 106 | + # align size to avoid misaligned address |
| 107 | + align_size = 256 |
| 108 | + |
| 109 | + def get_size(p: torch.Tensor) -> int: |
| 110 | + return (p.nbytes + align_size - 1) // align_size * align_size |
| 111 | + |
| 112 | + named_parameters: dict[str, torch.nn.Parameter] = dict( |
| 113 | + self.model.named_parameters() |
| 114 | + ) |
| 115 | + max_tensor_size = max(get_size(p) for p in named_parameters.values()) |
| 116 | +        # use max_tensor_size * 2 as the buffer size so one bucket always fits the largest tensor and can still pack several smaller ones together
| 117 | + buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0") |
| 118 | + s = self.zmq_context.socket(zmq.REQ) |
| 119 | + s.bind(self.zmq_handle) |
| 120 | + handle = reduce_tensor(buffer) |
| 121 | + |
| 122 | + offset = 0 |
| 123 | + buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] |
| 124 | + named_tensors: list[dict] = [] |
| 125 | + real_tensors: list[torch.Tensor] = [] |
| 126 | + for name, p in named_parameters.items(): |
| 127 | + size = get_size(p) |
| 128 | + if offset + size > buffer.numel(): |
| 129 | + buckets.append((named_tensors, real_tensors)) |
| 130 | + named_tensors, real_tensors = [], [] |
| 131 | + offset = 0 |
| 132 | + # assume tensors are contiguous |
| 133 | + named_tensors.append( |
| 134 | + {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} |
| 135 | + ) |
| 136 | + real_tensors.append(p) |
| 137 | + offset += size |
| 138 | + if named_tensors: |
| 139 | + buckets.append((named_tensors, real_tensors)) |
| 140 | + s.send_pyobj(handle) |
| 141 | + s.recv() |
| 142 | + for named_tensors, real_tensors in buckets: |
| 143 | + offset = 0 |
| 144 | + for p in real_tensors: |
| 145 | + buffer[offset : offset + p.nbytes].data.copy_( |
| 146 | + p.data.view(-1).view(dtype=torch.uint8), non_blocking=True |
| 147 | + ) |
| 148 | + offset += get_size(p) |
| 149 | + torch.cuda.synchronize() |
| 150 | + s.send_pyobj(named_tensors) |
| 151 | + s.recv() |
| 152 | + s.send_pyobj(None) |
| 153 | + s.recv() |
| 154 | + s.close() |
| 155 | + del buffer |
| 156 | + gc.collect() |
| 157 | + torch.cuda.empty_cache() |
103 | 158 |
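The hunk above only adds the sending half of the exchange: each training actor stages parameters into a shared CUDA buffer, publishes that buffer once via a `reduce_tensor` IPC handle, and then streams bucket metadata over a ZMQ REQ socket. As a rough sketch of the receiving half — the `update_weights_from_ipc` method that `collective_rpc` invokes on each inference engine further down — the worker could rebuild the staging buffer from the handle and reconstruct each bucket from the metadata. The method name and message format come from this diff; the standalone-function form, the `model` / `device_uuid` arguments, and `model.load_weights` accepting `(name, tensor)` pairs are assumptions for illustration, not part of the change.

import torch
import zmq


def update_weights_from_ipc(model, device_uuid: str, zmq_handles: dict[str, str]):
    """Hypothetical receiver sketch mirroring the protocol of update_weights() above."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.connect(zmq_handles[device_uuid])

    # First message: the reduce_tensor() handle of the shared staging buffer.
    rebuild_func, rebuild_args = sock.recv_pyobj()
    buffer = rebuild_func(*rebuild_args)  # re-materialize the CUDA buffer in this process
    sock.send(b"")

    while True:
        # Then one message per bucket of tensor metadata, and None to finish.
        named_tensors = sock.recv_pyobj()
        if named_tensors is None:
            sock.send(b"")
            break
        weights = []
        for item in named_tensors:
            nbytes = item["dtype"].itemsize * item["shape"].numel()
            flat = buffer[item["offset"] : item["offset"] + nbytes]
            weights.append((item["name"], flat.view(item["dtype"]).view(item["shape"])))
        model.load_weights(weights)
        # The sender reuses the buffer for the next bucket as soon as we reply,
        # so the copies triggered above must have finished first.
        torch.cuda.synchronize()
        sock.send(b"")
    sock.close()

Because the REQ side blocks in recv() until the REP side answers each message, update_weights on the training actors and update_weights_from_ipc on the engines have to run concurrently, which is why the script below launches both groups in a single ray.get; driving them one after the other would deadlock.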
|
104 | 159 |
|
105 | 160 | # Ray manages four GPUs. |
@@ -175,18 +230,22 @@ def get_weight_ipc_handles(self): |
175 | 230 | # the second inference engine. |
176 | 231 | assert training_actor_device_ids[2:] == inference_engine_device_ids[1] |
177 | 232 |
|
178 | | -print("Gather all the IPC handles from the training actors.") |
179 | | -ipc_handles = {} |
| 233 | +print("Gather all the ZMQ handles from the training actors.") |
| 234 | +zmq_handles = {} |
180 | 235 | for actor in training_actors: |
181 | | - ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) |
| 236 | + zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) |
| 237 | + |
| 238 | +print(f"ZMQ handles: {zmq_handles}") |
182 | 239 |
|
183 | 240 | print("Update the weights of the inference engines.") |
184 | | -for llm in inference_engines: |
185 | | - ray.get( |
186 | | - llm.collective_rpc.remote( |
187 | | - "update_weights_from_ipc_handles", args=(ipc_handles,) |
188 | | - ) |
189 | | - ) |
| 241 | +ray.get( |
| 242 | + [actor.update_weights.remote() for actor in training_actors] |
| 243 | + + [ |
| 244 | + llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,)) |
| 245 | + for llm in inference_engines |
| 246 | + ] |
| 247 | +) |
| 248 | + |
190 | 249 | print("Check if the weights are updated.") |
191 | 250 | for llm in inference_engines: |
192 | 251 | assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) |
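The closing assertion calls a check_weights_changed RPC whose body lies outside this diff. A minimal sketch, assuming (as in vLLM's RLHF examples) that the training actors zero out their weights so an update is easy to detect, and that the worker extension reaches the model via self.model_runner.model:

def check_weights_changed(self) -> bool:
    # After a successful update, every parameter should match the zeroed
    # training weights that were streamed into this engine.
    for _, p in self.model_runner.model.named_parameters():
        if not torch.allclose(p, torch.zeros_like(p)):
            return False
    return True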