Skip to content

Commit 5acac93

Browse files
Binyang2014caiomcbrchhwang
authored
Integrate MSCCL++ DSL to torch workload (#620)
Provides two integration ways for MSCCL++ DSL. 1. Integrate with customized communication group 2. Integrate with NCCL API Introduce new Python APIs to make it work: ```python mscclpp.compile # compile dsl to json based execution plan mscclpp.ExecutionPlanRegistry.register_plan(plan) # register the compiled plan to the ExecutionPlanRegistry mscclpp.ExecutionPlanRegistry.set_selector(selector) # set the selector; the selector will return the best execution plan based on collective, message size, world size.... ``` Fix #556 --------- Co-authored-by: Caio Rocha <caiorocha@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
1 parent 9994f53 commit 5acac93

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1438
-277
lines changed

apps/nccl/src/allreduce.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,6 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
12371237

12381238
size_t scratchBufferSize_;
12391239
std::shared_ptr<char> scratchBuffer_;
1240-
const int nSegmentsForScratchBuffer_ = 2;
12411240
const size_t nvlsBufferSize_ = (1 << 30);
12421241

12431242
std::shared_ptr<uint32_t> deviceFlag_;

apps/nccl/src/nccl.cu

Lines changed: 51 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ static inline int mscclppNcclDlopenInit() {
120120
return dlopenSuccess;
121121
}
122122

123-
static inline void mscclppNcclDlopenFinalize() {
123+
// No need to call this function, handle will be closed at program exit
124+
[[maybe_unused]] static inline void mscclppNcclDlopenFinalize() {
124125
if (mscclppNcclDlHandle) {
125126
dlclose(mscclppNcclDlHandle);
126127
}
@@ -159,17 +160,6 @@ static bool tryLoadNcclSharedLib() {
159160
// Declare the global map to store associations between raw pointer and shared pointer
160161
static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;
161162

162-
struct planKey {
163-
size_t minMessageSize;
164-
size_t maxMessageSize;
165-
bool isInPlace;
166-
};
167-
168-
struct executionPlanInstance {
169-
planKey key;
170-
std::shared_ptr<mscclpp::ExecutionPlan> plan;
171-
};
172-
173163
struct splitCommInfo {
174164
int color;
175165
int key;
@@ -179,23 +169,16 @@ struct splitCommInfo {
179169
struct ncclComm {
180170
std::shared_ptr<mscclpp::Communicator> comm;
181171
std::shared_ptr<mscclpp::Executor> executor;
182-
std::unordered_map<std::string, std::vector<executionPlanInstance>> executionPlans;
183172
std::shared_ptr<mscclpp::AlgorithmCollection> algorithmCollection;
184173
std::shared_ptr<char> scratchBuffer_;
185174
const size_t scratchBufferSize_ = (1 << 27); // 128MB
175+
std::shared_ptr<mscclpp::ExecutionPlanRegistry> planRegistry_;
186176
int nRanksPerNode;
187177
int worldSize;
188178

189179
void* mscclppNcclComm;
190180
};
191181

192-
static std::pair<std::string, executionPlanInstance> loadExecutionPlan(const std::string& filename, int rank) {
193-
std::shared_ptr<mscclpp::ExecutionPlan> plan = std::make_shared<mscclpp::ExecutionPlan>(filename, rank);
194-
std::string collective = plan->collective();
195-
planKey key{plan->minMessageSize(), plan->maxMessageSize(), plan->isInPlace()};
196-
return std::make_pair(collective, executionPlanInstance{key, plan});
197-
}
198-
199182
static ncclResult_t executeWithPlan(std::shared_ptr<mscclpp::Executor> executor, int rank, ncclDataType_t datatype,
200183
const void* sendbuff, void* recvbuff, size_t sendBytes, size_t recvBytes,
201184
std::shared_ptr<mscclpp::ExecutionPlan> plan, cudaStream_t stream) {
@@ -352,6 +335,20 @@ static mscclpp::Algorithm algoSelector(
352335
return mscclpp::Algorithm();
353336
}
354337

338+
std::shared_ptr<mscclpp::ExecutionPlanHandle> executionPlanDefaultSelector(
339+
const std::vector<std::shared_ptr<mscclpp::ExecutionPlanHandle>> plans, const mscclpp::ExecutionRequest&) {
340+
if (plans.empty()) {
341+
INFO(MSCCLPP_NCCL, "No execution plans available for selection");
342+
return nullptr;
343+
}
344+
for (auto plan : plans) {
345+
if (plan->tags.find("default") == plan->tags.end()) {
346+
return plan;
347+
}
348+
}
349+
return plans[0];
350+
}
351+
355352
NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
356353
INFO(MSCCLPP_NCCL, "Initializing NCCL communicator for rank %d, world_size=%d", rank, nranks);
357354
if (comm == nullptr) {
@@ -371,29 +368,13 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
371368

372369
commPtr->comm = mscclppComm;
373370
commPtr->scratchBuffer_ = mscclpp::GpuBuffer<char>(commPtr->scratchBufferSize_).memory();
374-
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
371+
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm, commPtr->scratchBuffer_);
372+
commPtr->planRegistry_ = mscclpp::ExecutionPlanRegistry::getInstance();
373+
375374
commPtr->nRanksPerNode = mscclppComm->bootstrap()->getNranksPerNode();
376375
commPtr->worldSize = mscclppComm->bootstrap()->getNranks();
377-
378-
if (commPtr->worldSize == 1) {
379-
*comm = commPtr;
380-
return ncclSuccess;
381-
}
382-
383-
const std::string& collectiveDir = mscclpp::env()->executionPlanDir;
384-
if (collectiveDir != "") {
385-
if (!std::filesystem::is_directory(collectiveDir)) {
386-
WARN("The value of the environment variable %s is not a directory", collectiveDir.c_str());
387-
return ncclInvalidArgument;
388-
}
389-
for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) {
390-
if (entry.is_regular_file()) {
391-
auto plan = loadExecutionPlan(entry.path(), rank);
392-
commPtr->executionPlans[plan.first].push_back(plan.second);
393-
}
394-
}
395-
}
396-
376+
commPtr->planRegistry_->loadDefaultPlans(rank);
377+
commPtr->planRegistry_->setDefaultSelector(executionPlanDefaultSelector);
397378
mscclpp::AlgorithmCollectionBuilder::getInstance()->setFallbackAlgorithmSelector(algoSelector);
398379
registerCustomizedAlgo();
399380
commPtr->algorithmCollection = mscclpp::AlgorithmCollectionBuilder::getInstance()->build();
@@ -462,12 +443,12 @@ NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) {
462443
}
463444
#endif
464445

465-
if (mscclppNcclDlopenSharedLib == true) {
466-
mscclppNcclOps.CommDestroy(*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm));
467-
mscclppNcclDlopenFinalize();
468-
delete static_cast<ncclComm_t*>(comm->mscclppNcclComm);
469-
}
446+
ncclComm_t* mscclppNcclCommPtr = reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm);
470447
delete comm;
448+
if (mscclppNcclCommPtr != nullptr) {
449+
mscclppNcclOps.CommDestroy(*reinterpret_cast<ncclComm_t*>(mscclppNcclCommPtr));
450+
delete static_cast<ncclComm_t*>(mscclppNcclCommPtr);
451+
}
471452
return ncclSuccess;
472453
}
473454

@@ -646,18 +627,13 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
646627
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
647628
}
648629

649-
std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
650-
std::shared_ptr<mscclpp::ExecutionPlan> plan;
651-
bool inPlace = sendbuff == recvbuff;
652-
for (const auto& p : plans) {
653-
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
654-
plan = p.plan;
655-
break;
656-
}
657-
}
658-
659-
if (plan != nullptr) {
660-
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, plan, stream);
630+
static std::unordered_map<std::string, std::vector<uint64_t>> hints{{"root", {static_cast<uint64_t>(root)}}};
631+
hints["root"][0] = static_cast<uint64_t>(root);
632+
auto planHandle = comm->planRegistry_->select("broadcast", comm->comm->bootstrap()->getNranks(),
633+
comm->comm->bootstrap()->getNranksPerNode(),
634+
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, hints);
635+
if (planHandle != nullptr) {
636+
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, planHandle->plan, stream);
661637
}
662638
auto algo = comm->algorithmCollection->selectAlgorithm(
663639
"broadcast", sendbuff, recvbuff, count * ncclTypeSize(datatype), datatype,
@@ -706,18 +682,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
706682
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
707683
}
708684

709-
std::vector<executionPlanInstance>& plans = comm->executionPlans["allreduce"];
710-
std::shared_ptr<mscclpp::ExecutionPlan> plan;
711-
bool inPlace = sendbuff == recvbuff;
712-
for (const auto& p : plans) {
713-
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
714-
plan = p.plan;
715-
break;
716-
}
717-
}
718-
719-
if (plan != nullptr) {
720-
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, plan, stream);
685+
auto planHandler = comm->planRegistry_->select("allreduce", comm->comm->bootstrap()->getNranks(),
686+
comm->comm->bootstrap()->getNranksPerNode(),
687+
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, {});
688+
if (planHandler != nullptr) {
689+
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, planHandler->plan, stream);
721690
}
722691

723692
auto algo = comm->algorithmCollection->selectAlgorithm(
@@ -769,20 +738,12 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
769738
int rank = comm->comm->bootstrap()->getRank();
770739
int nRank = comm->comm->bootstrap()->getNranks();
771740

772-
std::vector<executionPlanInstance>& plans = comm->executionPlans["reducescatter"];
773-
std::shared_ptr<mscclpp::ExecutionPlan> plan;
774-
void* basePtr = (char*)sendbuff + rank * bytes;
775-
bool inPlace = basePtr == recvbuff;
776-
const size_t totalBytes = bytes * nRank;
777-
for (const auto& p : plans) {
778-
if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
779-
plan = p.plan;
780-
break;
781-
}
782-
}
783-
784-
if (plan != nullptr) {
785-
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, totalBytes, bytes, plan, stream);
741+
auto planHandle = comm->planRegistry_->select("reducescatter", comm->comm->bootstrap()->getNranks(),
742+
comm->comm->bootstrap()->getNranksPerNode(),
743+
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, {});
744+
if (planHandle != nullptr) {
745+
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes * nRank, bytes, planHandle->plan,
746+
stream);
786747
}
787748

788749
if (mscclppNcclDlopenSharedLib == true) {
@@ -821,20 +782,12 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
821782
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
822783
}
823784

824-
std::vector<executionPlanInstance>& plans = comm->executionPlans["allgather"];
825-
std::shared_ptr<mscclpp::ExecutionPlan> plan;
826-
void* basePtr = (char*)sendbuff - rank * bytes;
827-
bool inPlace = basePtr == recvbuff;
828-
const size_t totalBytes = bytes * nRank;
829-
for (const auto& p : plans) {
830-
if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
831-
plan = p.plan;
832-
break;
833-
}
834-
}
835-
836-
if (plan != nullptr) {
837-
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, totalBytes, plan, stream);
785+
auto planHandle = comm->planRegistry_->select("allgather", comm->comm->bootstrap()->getNranks(),
786+
comm->comm->bootstrap()->getNranksPerNode(),
787+
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, {});
788+
if (planHandle != nullptr) {
789+
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes * nRank, planHandle->plan,
790+
stream);
838791
}
839792

840793
auto algo = comm->algorithmCollection->selectAlgorithm(

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"show-inheritance": True,
5050
}
5151
# only mock the C-extension when using the source tree
52-
autodoc_mock_imports = ["mscclpp._version", "mscclpp._mscclpp", "cupy", "mpi4py", "numpy", "sortedcontainers"]
52+
autodoc_mock_imports = ["mscclpp._version", "mscclpp._mscclpp", "blake3", "cupy", "mpi4py", "numpy", "sortedcontainers"]
5353
autodoc_typehints = "description"
5454
napoleon_google_docstring = True
5555
napoleon_numpy_docstring = True
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# MSCCL++ DSL Integration Guide
2+
3+
MSCCL++ DSL (domain-specific language) enables concise expression of collective algorithms as Python functions.
4+
MSCCL++ offers pythonic utilities to author, JIT-compile, register, and select execution plans. This guide walks through two integration paths: a customized MSCCL++ communicator and NCCL interposition that accelerates existing PyTorch `backend="nccl"` workloads.
5+
6+
## Initial Setup
7+
8+
Run the following from the repository root after completing the basic project setup:
9+
10+
1. Install Python dependencies.
11+
```bash
12+
pip install -r ./python/<requirements_file>
13+
```
14+
Replace `<requirements_file>` with the file that matches your environment (e.g., `requirements_cuda11.txt`, `requirements_cuda12.txt`, or `requirements_rocm6.txt`).
15+
16+
2. Install the module and generate default algorithm plans.
17+
```bash
18+
pip install . && python3 -m mscclpp --install
19+
```
20+
21+
## Integration Options
22+
23+
MSCCL++ DSL integrates into your training or inference workload in two ways:
24+
1. **Custom MSCCL++ Communicator** — directly manage an MSCCL++ communicator and launch collectives with the MSCCL++ executor.
25+
2. **NCCL Interposition** — keep using `backend="nccl"`; MSCCL++ intercepts NCCL calls at runtime for drop-in acceleration.
26+
27+
Both paths follow the same high-level flow:
28+
1. Author (or reuse) a collective algorithm with the MSCCL++ DSL.
29+
2. Compile it into an execution plan.
30+
3. Register the plan with the MSCCL++ runtime.
31+
4. Configure a selector to choose the plan for each collective call.
32+
33+
Below we show an AllReduce example and then detail each integration option.
34+
35+
### Example: AllReduce in the MSCCL++ DSL
36+
The snippet defines an AllReduce that uses NVLS for intra-node reduce-scatter followed by broadcast.
37+
```python
38+
def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
39+
gpu_size = spec.world_size
40+
with CollectiveProgram(
41+
spec.name,
42+
spec.collective,
43+
gpu_size,
44+
instances=8,
45+
protocol=spec.protocol,
46+
num_threads_per_block=spec.num_threads_per_block,
47+
min_message_size=spec.min_message_size,
48+
max_message_size=spec.max_message_size,
49+
) as program:
50+
# Creating Channels
51+
nvls_chan = SwitchChannel(rank_list=[gpu for gpu in range(gpu_size)], buffer_type=BufferType.input)
52+
channels = {}
53+
for gpu in range(gpu_size):
54+
for peer in range(gpu_size):
55+
if peer != gpu:
56+
channels[(peer, gpu)] = MemoryChannel(peer, gpu)
57+
58+
# Synchronization to Ensure all the Gpus are Ready
59+
for gpu in range(gpu_size):
60+
src_rank = gpu
61+
for peer in range(gpu_size):
62+
if peer != src_rank:
63+
dst_rank = peer
64+
channels[(dst_rank, src_rank)].signal(tb=0, relaxed=True)
65+
for peer in range(gpu_size):
66+
if peer != src_rank:
67+
dst_rank = peer
68+
channels[(dst_rank, src_rank)].wait(tb=0, relaxed=True, data_sync=SyncType.after)
69+
# Reducing and Storing the data
70+
for gpu in range(gpu_size):
71+
buffer_offset = gpu
72+
rank = Rank(gpu)
73+
input_buffer = rank.get_input_buffer()
74+
nvls_chan.at_rank(gpu).reduce(
75+
buffer_offset=buffer_offset, size=1, dst_chunk=input_buffer[gpu : gpu + 1], tb=0
76+
)
77+
nvls_chan.at_rank(gpu).broadcast(
78+
src_chunk=input_buffer[gpu : gpu + 1], buffer_offset=buffer_offset, size=1, tb=0
79+
)
80+
# Synchronization to Ensure the Gpus finished
81+
for gpu in range(gpu_size):
82+
src_rank = gpu
83+
for peer in range(gpu_size):
84+
if peer != src_rank:
85+
dst_rank = peer
86+
channels[(dst_rank, src_rank)].signal(tb=0, relaxed=True, data_sync=SyncType.before)
87+
for peer in range(gpu_size):
88+
if peer != src_rank:
89+
dst_rank = peer
90+
channels[(dst_rank, src_rank)].wait(tb=0, relaxed=True)
91+
92+
return program
93+
```
94+
95+
### Integrate with MSCCL++ customized communicator
96+
Use when you want a PyTorch-compatible interface with fine-grained control. You manage the communicator, compile/register DSL plans, and invoke collectives via a thin wrapper. The example below shows an AllReduce built on the MSCCL++ communicator and executor.
97+
Example source directory:
98+
```
99+
examples/torch-integration
100+
```
101+
Key file: `customized_comm.py`.
102+
103+
104+
#### Launch (single node)
105+
```bash
106+
MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm.py
107+
```
108+
109+
### Integrate via NCCL Interposition
110+
Keep your script as-is: initialize PyTorch with `backend="nccl"`; MSCCL++ intercepts NCCL calls for drop-in acceleration.
111+
Example source directory:
112+
```
113+
examples/torch-integration
114+
```
115+
Key file: `dsl_with_nccl_api.py`.
116+
117+
#### Launch with interposition
118+
To run with NCCL interposition, you preload the MSCCL++ shim so it transparently intercepts NCCL calls made by PyTorch’s nccl backend.
119+
```bash
120+
LD_PRELOAD=<MSCCLPP_REPO>/build/apps/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
121+
```
122+
## Notices:
123+
- When using NCCL interposition, the algorithm selection order is:
124+
1. Check for registered DSL plans matching the collective call.
125+
2. Check for a customized kernel implementation if no DSL plan fits.
126+
3. Fall back to the default NCCL implementation (set `MSCCLPP_NCCL_LIB_PATH` to the original NCCL library).

docs/programming_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ This section provides advanced topics and best practices for using MSCCL++. It i
1313
guide/cpp-examples
1414
guide/mscclpp-dsl
1515
guide/customized-algorithm-with-nccl-api
16+
guide/mscclpp-dsl-integration

0 commit comments

Comments
 (0)