-# Adapted from
+# Adapted from
 # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/distributed/device_communicators/pynccl_wrapper.py
 # of the vllm-project/vllm GitHub repository.
 #
@@ -146,46 +146,43 @@ class NCCLLibrary:
         # const char* ncclGetErrorString(ncclResult_t result)
         Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
         # ncclResult_t ncclGetVersion(int *version);
-        Function("ncclGetVersion", ncclResult_t,
-                 [ctypes.POINTER(ctypes.c_int)]),
+        Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
         # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
-        Function("ncclGetUniqueId", ncclResult_t,
-                 [ctypes.POINTER(ncclUniqueId)]),
+        Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
         # ncclResult_t ncclCommInitRank(
         #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
         # note that ncclComm_t is a pointer type, so the first argument
         # is a pointer to a pointer
-        Function("ncclCommInitRank", ncclResult_t, [
-            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
-            ctypes.c_int
-        ]),
+        Function(
+            "ncclCommInitRank", ncclResult_t, [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int]
+        ),
         # ncclResult_t ncclAllReduce(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
         #   cudaStream_t stream);
         # note that cudaStream_t is a pointer type, so the last argument
         # is a pointer
-        Function("ncclAllReduce", ncclResult_t, [
-            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
-            ncclRedOp_t, ncclComm_t, cudaStream_t
-        ]),
-
+        Function(
+            "ncclAllReduce",
+            ncclResult_t,
+            [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t],
+        ),
         # ncclResult_t ncclSend(
         #   const void* sendbuff, size_t count, ncclDataType_t datatype,
         #   int dest, ncclComm_t comm, cudaStream_t stream);
-        Function("ncclSend", ncclResult_t, [
-            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
-            ncclComm_t, cudaStream_t
-        ]),
-
+        Function(
+            "ncclSend",
+            ncclResult_t,
+            [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t],
+        ),
         # ncclResult_t ncclRecv(
         #   void* recvbuff, size_t count, ncclDataType_t datatype,
         #   int src, ncclComm_t comm, cudaStream_t stream);
-        Function("ncclRecv", ncclResult_t, [
-            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
-            ncclComm_t, cudaStream_t
-        ]),
-
+        Function(
+            "ncclRecv",
+            ncclResult_t,
+            [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t],
+        ),
         # be cautious! this is a collective call, it will block until all
         # processes in the communicator have called this function.
         # because Python object destruction can happen in random order,
@@ -219,8 +216,10 @@ def __init__(self, so_file: Optional[str] = None):
                 "or it does not support the current platform %s."
                 "If you already have the library, please set the "
                 "environment variable VLLM_NCCL_SO_PATH"
-                " to point to the correct nccl library path.", so_file,
-                platform.platform())
+                " to point to the correct nccl library path.",
+                so_file,
+                platform.platform(),
+            )
             raise e
 
         if so_file not in NCCLLibrary.path_to_dict_mapping:
@@ -253,45 +252,51 @@ def ncclGetVersion(self) -> str:
 
     def ncclGetUniqueId(self) -> ncclUniqueId:
         unique_id = ncclUniqueId()
-        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
-            ctypes.byref(unique_id)))
+        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
         return unique_id
 
-    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
-                         rank: int) -> ncclComm_t:
+    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t:
         comm = ncclComm_t()
-        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
-                                                        world_size, unique_id,
-                                                        rank))
+        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), world_size, unique_id, rank))
         return comm
 
-    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
-                      count: int, datatype: int, op: int, comm: ncclComm_t,
-                      stream: cudaStream_t) -> None:
+    def ncclAllReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
         # `datatype` actually should be `ncclDataType_t`
         # and `op` should be `ncclRedOp_t`
         # both are aliases of `ctypes.c_int`
         # when we pass int to a function, it will be converted to `ctypes.c_int`
         # by ctypes automatically
-        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
-                                                     datatype, op, comm,
-                                                     stream))
+        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
 
-    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
-                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
-        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
-                                                dest, comm, stream))
+    def ncclSend(
+        self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t
+    ) -> None:
+        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream))
 
-    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
-                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
-        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
-                                                comm, stream))
+    def ncclRecv(
+        self, recvbuff: buffer_type, count: int, datatype: int, src: int, comm: ncclComm_t, stream: cudaStream_t
+    ) -> None:
+        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream))
 
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
 
 __all__ = [
-    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
-    "ncclComm_t", "cudaStream_t", "buffer_type"
+    "NCCLLibrary",
+    "ncclDataTypeEnum",
+    "ncclRedOpTypeEnum",
+    "ncclUniqueId",
+    "ncclComm_t",
+    "cudaStream_t",
+    "buffer_type",
 ]
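
The methods touched above are thin ctypes shims over the NCCL C API, so the calling pattern is unchanged by this reformat. The sketch below shows how the wrapper is typically driven end to end: get a unique id, init a communicator, run an in-place all-reduce on the current CUDA stream, then destroy the communicator. It is a hypothetical single-rank example, not code from this commit; it assumes the module is importable as `pynccl_wrapper`, that a CUDA-capable PyTorch and the NCCL shared library are present, and that the enum helpers expose the usual NCCL constant names (`ncclFloat32`, `ncclSum`).

```python
# Hypothetical usage sketch (not part of the diff): one process, world_size=1.
import torch

from pynccl_wrapper import (
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclDataTypeEnum,
    ncclRedOpTypeEnum,
)

nccl = NCCLLibrary()                            # dlopen NCCL and bind the exported functions
unique_id = nccl.ncclGetUniqueId()              # rank 0 would normally broadcast this to peers
comm = nccl.ncclCommInitRank(1, unique_id, 0)   # nranks=1, rank=0

x = torch.ones(1024, device="cuda")
stream = torch.cuda.current_stream()
nccl.ncclAllReduce(
    buffer_type(x.data_ptr()),                  # sendbuff
    buffer_type(x.data_ptr()),                  # recvbuff (in-place)
    x.numel(),                                  # element count, not bytes
    ncclDataTypeEnum.ncclFloat32,               # assumed enum attribute name
    ncclRedOpTypeEnum.ncclSum,                  # assumed enum attribute name
    comm,
    cudaStream_t(stream.cuda_stream),
)
torch.cuda.synchronize()
assert torch.allclose(x, torch.ones_like(x))    # sum over a single rank is a no-op
nccl.ncclCommDestroy(comm)
```

In a real multi-rank setup, the unique id from rank 0 has to be shipped to the other ranks out of band (for example via `torch.distributed`) before each rank calls `ncclCommInitRank`, and `ncclCommDestroy` must be invoked by every rank, which is exactly the collective-teardown caveat the comment in the diff warns about.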