
Commit 6b2bef1

fduwjj authored and pytorchmergebot committed
[c10d] Prototype of group_split for dist2 work (pytorch#157716)
This implements group_split as proposed in the dist2 design doc: https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89

Pull Request resolved: pytorch#157716
Approved by: https://github.com/d4l3k
1 parent 1e4d8b5 commit 6b2bef1
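
For orientation, here is a minimal usage sketch of the new API, mirroring the test_group_split test added below. How `group` and `rank` are obtained is a placeholder (any dist2 process group spanning all ranks will do); only ranks passed to split_group get a subgroup back.

    from datetime import timedelta

    # `group` is assumed to be an existing dist2 ProcessGroup spanning all
    # ranks; `rank` is this process's global rank.
    subgroup = group.split_group([0], timeout=timedelta(seconds=30))

    if rank == 0:
        # Rank 0 is part of the split: it receives a 1-rank subgroup whose
        # backend options carry the requested timeout.
        assert subgroup is not None
        assert subgroup.size() == 1
    else:
        # Ranks outside the split receive None.
        assert subgroup is None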

File tree

14 files changed: +246 −7 lines changed


test/cpp/c10d/ProcessGroupNCCLTest.cpp

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@ class NCCLTestBase {
 
   NCCLTestBase(NCCLTestBase&& other) noexcept = default;
 
-  std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
+  ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
     return pg_;
   }
 
@@ -39,7 +39,7 @@ class NCCLTestBase {
   void initialize(
       int rank,
       size_t size,
-      std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
+      std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from =
          std::nullopt) {
     store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
 
@@ -52,13 +52,13 @@ class NCCLTestBase {
       opts->split_color = ++color_;
     }
 #endif
-    pg_ = std::make_unique<::c10d::ProcessGroupNCCL>(
+    pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
        store_, rank, size, std::move(opts));
   }
 
 protected:
  std::string path_;
-  std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
+  ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_;
  std::chrono::milliseconds pgTimeout_;
  ::c10::intrusive_ptr<::c10d::Store> store_;
  int color_{1};

test/distributed/test_dist2.py

Lines changed: 11 additions & 0 deletions
@@ -201,6 +201,17 @@ def test_alltoall_base(self) -> None:
             out_range = out[i * 10 : (i + 1) * 10]
             self.assertEqual(out_range, torch.full_like(out_range, i + 1))
 
+    def test_group_split(self) -> None:
+        group = self.new_group()
+        subgroup = group.split_group([0], timeout=timedelta(seconds=30))
+        if self.rank == 0:
+            assert subgroup is not None
+            self.assertEqual(subgroup.size(), 1)
+            backend = subgroup._get_backend(self.device)
+            self.assertEqual(backend.options._timeout, timedelta(seconds=30))
+        else:
+            self.assertEqual(subgroup, None)
+
 
 class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
     device = torch.device("cpu")

torch/_C/_distributed_c10d.pyi

Lines changed: 7 additions & 0 deletions
@@ -350,6 +350,13 @@ class ProcessGroup:
     ) -> None: ...
     def rank(self) -> int: ...
     def size(self) -> int: ...
+    def split_group(
+        self,
+        new_ranks: list[int],
+        timeout: Optional[timedelta] = None,
+        pg_options: Optional[Backend.Options] = None,
+        group_desc: Optional[str] = None,
+    ) -> Optional[ProcessGroup]: ...
     def abort(self) -> None: ...
     def set_timeout(self, timeout: timedelta) -> None: ...
     def shutdown(self) -> None: ...

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 19 additions & 0 deletions
@@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   // backend name
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const std::string backend;
+  std::string group_name;
 };
 
 explicit Backend(int rank, int size);
@@ -105,6 +106,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
     TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
   }
 
+  // Subclasses must override this method to return the backend options
+  virtual c10::intrusive_ptr<Options> getBackendOptions() {
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "Backend ",
+            getBackendName(),
+            " does not implement getBackendOptions"));
+  }
+
   virtual c10::intrusive_ptr<Work> broadcast(
       std::vector<at::Tensor>& /* tensors */,
       const BroadcastOptions& /* opts */ = BroadcastOptions()) {
@@ -379,6 +388,16 @@ class TORCH_API Backend : public torch::CustomClassHolder {
       " is missing implementation of enableCollectivesTiming.");
   }
 
+  virtual c10::intrusive_ptr<Backend> splitBackend(
+      const std::vector<int>& ranks,
+      const c10::intrusive_ptr<Options> opts) {
+    TORCH_CHECK(
+        false,
+        "Backend ",
+        getBackendName(),
+        " is missing implementation of splitBackend.");
+  }
+
   bool hasHooks() const {
     return onCompletionHook_ != nullptr;
   }

torch/csrc/distributed/c10d/NCCLUtils.cpp

Lines changed: 21 additions & 0 deletions
@@ -573,6 +573,27 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
   return hash;
 }
 
+// NCCL requires a non-negative int to represent group membership per its API.
+// We take a list of ranks, generate a hash value from the list, and clamp the
+// result to the range of a non-negative 32-bit int.
+int genNcclSplitColor(const std::vector<int>& ranks) {
+  // Combine the hash values using a simple reducer (std::hash + fold)
+  std::size_t combined_hash = std::accumulate(
+      ranks.begin(),
+      ranks.end(),
+      std::size_t(0),
+      [](std::size_t acc, int rank) {
+        return acc ^
+            (std::hash<int>{}(rank) + 0x9e3779b9 + (acc << 6) + (acc >> 2));
+      });
+
+  // max positive value of int32_t
+  constexpr int32_t max_c_int = std::numeric_limits<int32_t>::max();
+  int color = static_cast<int>(
+      std::abs(static_cast<int64_t>(combined_hash)) % max_c_int);
+  return color;
+}
+
 // Default value: 30 minutes
 int nccl_nonblocking_timeout() {
   static int timeout = -2; // -2 means not initialized
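
To make the fold above concrete, here is an illustrative Python rendition of genNcclSplitColor. It assumes std::hash<int> acts as the identity on non-negative ints (true in libstdc++, but implementation-defined), so it sketches the scheme rather than being a bit-exact port:

    MASK64 = (1 << 64) - 1

    def gen_nccl_split_color(ranks):
        # boost-style hash_combine fold over the rank list, kept mod 2**64
        acc = 0
        for rank in ranks:
            acc ^= (rank + 0x9E3779B9 + (acc << 6) + (acc >> 2)) & MASK64
        # reinterpret as signed 64-bit, then clamp to non-negative int32 range
        signed = acc - (1 << 64) if acc >= (1 << 63) else acc
        return abs(signed) % (2**31 - 1)

Since every member of the split must arrive at the same color, the caller hashes the sorted rank list (see sorted_ranks in ProcessGroup::splitGroup below).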

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 1 addition & 0 deletions
@@ -231,6 +231,7 @@ static std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
 };
 
 TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
+TORCH_API int genNcclSplitColor(const std::vector<int>& ranks);
 TORCH_API std::string getNcclVersion();
 TORCH_API std::tuple<int, int, int> getNcclVersionTuple();
 TORCH_API int getNcclVersionNumber();

torch/csrc/distributed/c10d/ProcessGroup.cpp

Lines changed: 58 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 #include <c10/util/Logging.h>
 #include <fmt/format.h>
+#include <fmt/ranges.h>
 #include <string_view>
 
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
@@ -158,6 +159,63 @@ void ProcessGroup::release_resources() {
   backendTypeToBackend_.clear();
 }
 
+c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
+    const std::vector<int>& ranks,
+    const std::optional<std::chrono::milliseconds> timeout,
+    const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+    const std::optional<std::string>& desc) {
+  TORCH_CHECK(
+      ranks.size() > 0,
+      "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
+  TORCH_CHECK(
+      ranks.size() < static_cast<size_t>(size_),
+      "the split group's size should be less than the world_size set by init_process_group");
+  std::set<int> ranks_set(ranks.begin(), ranks.end());
+  TORCH_CHECK(
+      ranks_set.size() == ranks.size(),
+      "Split ranks should not have duplicates. Please provide a list of unique ranks to split the group.");
+  std::vector<int> sorted_ranks = ranks;
+  std::sort(sorted_ranks.begin(), sorted_ranks.end());
+  c10::intrusive_ptr<ProcessGroup> newGroup;
+  // TODO: Figure out a better way for split group name.
+  std::string groupName =
+      c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
+  for (const auto& pair : deviceTypeToBackendType_) {
+    c10::DeviceType deviceType = pair.first;
+    BackendType backendType = pair.second;
+
+    auto parentBackend = getBackend(deviceType);
+    auto backendOpts =
+        opts.has_value() ? opts.value() : parentBackend->getBackendOptions();
+    backendOpts->group_name = groupName;
+    backendOpts->timeout =
+        timeout.has_value() ? timeout.value() : backendOpts->timeout;
+    auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
+    if (splitBackend == nullptr) {
+      continue;
+    }
+
+    // TODO: Figure out a better way for split group desc.
+    // TODO: We can add a new field in Backend::Options to specify the group
+    // desc.
+    std::string groupDesc = desc.has_value()
+        ? desc.value()
+        : c10::str(getGroupDesc(), ":split:", incrementSplitCount());
+    splitBackend->setGroupDesc(groupDesc);
+
+    if (!newGroup) {
+      newGroup = c10::make_intrusive<ProcessGroup>(
+          store_->clone(), splitBackend->getRank(), splitBackend->getSize());
+      newGroup->setDefaultBackend(backendType_);
+      newGroup->setGroupName(groupName);
+      newGroup->setGroupDesc(groupDesc);
+    }
+    newGroup->setBackend(deviceType, backendType, splitBackend);
+  }
+
+  return newGroup;
+}
+
 } // namespace c10d
 
 namespace {
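
A note on the derived names, since they are built by plain string concatenation above. A hypothetical illustration (the parent name and desc below are made up; real values depend on how the parent group was created):

    # groupName = parent name + ":split:" + the sorted rank list, which
    # fmt::format("{}", sorted_ranks) renders as "[0, 1]".
    sorted_ranks = [0, 1]
    parent_name = "default_pg"   # assumed parent group name
    group_name = f"{parent_name}:split:{sorted_ranks}"  # "default_pg:split:[0, 1]"

    # groupDesc defaults to parent desc + ":split:" + a per-parent counter.
    parent_desc = "undefined"    # assumed parent group desc
    split_count = 0              # incrementSplitCount() starts at 0
    group_desc = f"{parent_desc}:split:{split_count}"   # "undefined:split:0"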

torch/csrc/distributed/c10d/ProcessGroup.hpp

Lines changed: 13 additions & 0 deletions
@@ -170,6 +170,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     }
   }
 
+  int64_t incrementSplitCount() {
+    return splitCounter_++;
+  }
+
   virtual void startCoalescing(c10::DeviceType deviceType) {
     // only nccl has implemented startCoalescing so only execute for nccl
     // backends
@@ -955,6 +959,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     bound_device_id_ = device;
   }
 
+  // Creates a new subgroup from the specified ranks. On ranks that are not
+  // included in `ranks`, this returns a null pointer (None in Python).
+  virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
+      const std::vector<int>& ranks,
+      const std::optional<std::chrono::milliseconds> timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::string>& groupDesc);
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
@@ -968,6 +980,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   BackendType backendType_;
   std::string pg_desc_;
+  int64_t splitCounter_{0};
 
   // Debug level setting. It is parsed once when ProcessGroup is constructed and
   // remains the same across use of this process group.

torch/csrc/distributed/c10d/ProcessGroupGloo.cpp

Lines changed: 29 additions & 0 deletions
@@ -697,6 +697,35 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
   return options_->global_ranks_in_group;
 }
 
+c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
+    const std::vector<int>& ranks,
+    const c10::intrusive_ptr<Backend::Options> opts) {
+  auto it = std::find(ranks.begin(), ranks.end(), rank_);
+  int groupRank;
+  if (it == ranks.end()) {
+    return nullptr;
+  } else {
+    groupRank = std::distance(ranks.begin(), it);
+  }
+
+  auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
+  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
+
+  // TODO: we need to get rid of globalRanksInGroup eventually.
+  std::vector<uint64_t> globalRanksInGroup;
+  for (auto rank : ranks) {
+    globalRanksInGroup.emplace_back(groupRanks()[rank]);
+  }
+  glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
+  auto store = std::dynamic_pointer_cast<GlooStore>(store_);
+  TORCH_CHECK(
+      store != nullptr,
+      "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
+  auto pg = c10::make_intrusive<ProcessGroupGloo>(
+      store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
+  return c10::static_intrusive_pointer_cast<Backend>(pg);
+}
+
 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
   std::unique_lock<std::mutex> lock(workMutex_);
   pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);

torch/csrc/distributed/c10d/ProcessGroupGloo.hpp

Lines changed: 12 additions & 1 deletion
@@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }
 #endif
 
+    const c10::intrusive_ptr<::c10d::Store>& _getStore() const {
+      return store_;
+    }
+
    protected:
     c10::intrusive_ptr<::c10d::Store> store_;
   };
@@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }
 
     std::vector<uint64_t> global_ranks_in_group;
-    std::string group_name;
     std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
     int threads;
   };
@@ -301,6 +304,14 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }
   }
 
+  c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
+    return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
+  }
+
+  c10::intrusive_ptr<Backend> splitBackend(
+      const std::vector<int>& ranks,
+      const c10::intrusive_ptr<Backend::Options> opts) override;
+
   const std::vector<uint64_t>& groupRanks() const;
 
   c10::intrusive_ptr<Work> broadcast(
