Summary
This RFC proposes adding XCCL (Intel GPU Collective Communications Library) backend support to TorchComms, enabling efficient distributed communication on Intel GPUs. The backend leverages Intel's oneCCL library and integrates with PyTorch's XPU device support via the SYCL runtime.
Motivation
Currently, TorchComms supports NCCL (NVIDIA), RCCL (AMD), and Gloo (CPU) backends, but lacks native support for Intel GPUs. Since PyTorch provides distributed support for Intel GPUs through the xccl backend, we plan to integrate XCCL into TorchComms as well.
Design Overview
Architecture
The XCCL backend follows the same architectural pattern as existing backends (NCCL/RCCL):
Python API (torchcomms)
↓
TorchCommBackend Interface
↓
TorchCommXCCL Implementation
↓
├── XcclApi (oneCCL wrapper)
├── XpuApi (SYCL/XPU abstraction)
└── TorchCommXCCL (Collectives implementation)
Key Components
Core Files
| File | Purpose |
|---|---|
| `TorchCommXCCL.cpp/.hpp` | Collectives implementation |
| `TorchWorkXCCL.cpp/.hpp` | Asynchronous work tracking |
| `TorchWorkXCCLQueue.cpp` | Work queue management |
| `TorchCommXCCLBootstrap.cpp/.hpp` | Communicator initialization |
| `TorchCommXCCLUtils.cpp/.hpp` | Utility functions (type conversion) |
| `TorchCommXCCLPy.cpp` | Python bindings |
API Abstraction Layers
| File | Purpose |
|---|---|
| `XcclApi.cpp/.hpp` | oneCCL API wrapper |
| `device/XpuApi.cpp/.hpp` | SYCL/XPU runtime abstraction |
Build System
| File | Purpose |
|---|---|
| `xccl/CMakeLists.txt` | CMake configuration for the XCCL backend |
| `setup.py` | Python package build (`USE_XCCL` flag) |
API Surface
The XCCL backend implements all standard TorchComm operations:
Point-to-Point Operations
- `send()` / `recv()`
- `batch_op_issue()` (batched P2P)
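To make the point-to-point data flow concrete, here is a plain-Python sketch of a ring exchange, where each rank sends its buffer to its right neighbor and receives from its left. This illustrates only the semantics; `ring_exchange` is a hypothetical helper, not part of the torchcomms API, and no XPU device is involved.

```python
# Illustrative sketch of send()/recv() semantics (NOT the torchcomms API):
# in a ring exchange over n ranks, rank r sends its value to rank
# (r + 1) % n and receives from rank (r - 1) % n.

def ring_exchange(values):
    """Simulate one send/recv ring step; values[r] is rank r's buffer."""
    n = len(values)
    # received[r] is what rank r gets from its left neighbor
    return [values[(r - 1) % n] for r in range(n)]

print(ring_exchange([10, 20, 30, 40]))  # [40, 10, 20, 30]
```

In the real backend, each rank would issue its own `send()` and matching `recv()`, and batching them via `batch_op_issue()` avoids serializing the pairwise exchanges.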
Collective Operations
- `broadcast()`
- `all_reduce()`
- `reduce()`
- `all_gather()` / `all_gather_single()`
- `reduce_scatter()` / `reduce_scatter_single()`
- `all_to_all()` / `all_to_all_single()` / `all_to_all_v_single()`
- `scatter()` / `gather()`
- `barrier()`
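The expected results of these collectives can be illustrated with plain-Python reference versions over per-rank lists. These are illustrative only (not the torchcomms API, and device-free), but they show what each rank should hold after the operation completes.

```python
# Reference semantics for two of the collectives above, over per_rank,
# a list where per_rank[r] is rank r's input buffer (illustrative only).

def all_reduce_sum(per_rank):
    """Every rank ends up with the elementwise sum of all ranks' buffers."""
    total = [sum(col) for col in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def all_gather(per_rank):
    """Every rank ends up with the concatenation of all ranks' buffers."""
    gathered = [x for buf in per_rank for x in buf]
    return [list(gathered) for _ in per_rank]

inputs = [[1, 2], [3, 4]]      # 2 ranks, 2 elements each
print(all_reduce_sum(inputs))  # [[4, 6], [4, 6]]
print(all_gather(inputs))      # [[1, 2, 3, 4], [1, 2, 3, 4]]
```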
Advanced Operations
- `split()` - Create sub-communicators
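One common convention for splitting a communicator (as in MPI's `MPI_Comm_split`) is that ranks passing the same color end up in the same sub-communicator, ordered by a key. The sketch below shows that grouping in plain Python; `split_ranks` is a hypothetical helper for illustration, and torchcomms' actual `split()` signature may differ.

```python
# Sketch of split() grouping semantics (illustrative, not the torchcomms
# API): ranks with the same color join the same sub-communicator, ordered
# here by their original global rank.

def split_ranks(colors):
    """Map color -> ordered list of global ranks in that sub-communicator."""
    groups = {}
    for rank, color in enumerate(colors):
        groups.setdefault(color, []).append(rank)
    return groups

# 4 ranks split into even/odd groups
print(split_ranks([0, 1, 0, 1]))  # {0: [0, 2], 1: [1, 3]}
```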
Build from Source
```shell
# Install the Intel oneAPI Base Toolkit first
# Enable the XCCL backend when building TorchComms from source
USE_XCCL=ON python setup.py install

# Or, for development
USE_XCCL=ON pip install -e .
```
Usage Examples
Basic AllReduce
```python
import torch
import torchcomms
from torchcomms import new_comm

# Initialize communicator
device = torch.device('xpu')
comm = new_comm("xccl", device, name="main_comm")
rank = comm.get_rank()
current_device = torch.device(f"xpu:{rank}")

# Perform collective operation
tensor = torch.randn(1024, device=current_device)
work = comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=True)
work.wait()

# Cleanup
comm.finalize()
```
PR plan
✅ Core Infrastructure
- `TorchCommXCCL` class with full backend interface implementation
- `TorchWorkXCCL` for asynchronous operation tracking
- `TorchWorkXCCLQueue` for work management
- Python bindings (`_comms_xccl` module)
- CMake build system integration

✅ API Abstraction
- `XcclApi` wrapper for all oneCCL operations
- `XpuApi` for SYCL/XPU runtime operations

✅ Collective Operations
- `AllReduce` collectives as an entry point
- Other collectives