|
4 | 4 |
|
5 | 5 | #include <c10/util/Logging.h> |
6 | 6 | #include <fmt/format.h> |
| 7 | +#include <fmt/ranges.h> |
7 | 8 | #include <string_view> |
8 | 9 |
|
9 | 10 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
@@ -158,6 +159,63 @@ void ProcessGroup::release_resources() { |
158 | 159 | backendTypeToBackend_.clear(); |
159 | 160 | } |
160 | 161 |
|
| 162 | +c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup( |
| 163 | + const std::vector<int>& ranks, |
| 164 | + const std::optional<std::chrono::milliseconds> timeout, |
| 165 | + const std::optional<c10::intrusive_ptr<Backend::Options>> opts, |
| 166 | + const std::optional<std::string>& desc) { |
| 167 | + TORCH_CHECK( |
| 168 | + ranks.size() > 0, |
| 169 | + "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group."); |
| 170 | + TORCH_CHECK( |
| 171 | + ranks.size() < static_cast<size_t>(size_), |
| 172 | + "the split group's size should be less than the world_size set by init_process_group"); |
| 173 | + std::set<int> ranks_set(ranks.begin(), ranks.end()); |
| 174 | + TORCH_CHECK( |
| 175 | + ranks_set.size() == ranks.size(), |
| 176 | + "Split ranks should not have duplicates. Please provide a list of unique ranks to split the group."); |
| 177 | + std::vector<int> sorted_ranks = ranks; |
| 178 | + std::sort(sorted_ranks.begin(), sorted_ranks.end()); |
| 179 | + c10::intrusive_ptr<ProcessGroup> newGroup; |
| 180 | + // TODO: Figure out a better way for split group name. |
| 181 | + std::string groupName = |
| 182 | + c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks)); |
| 183 | + for (const auto& pair : deviceTypeToBackendType_) { |
| 184 | + c10::DeviceType deviceType = pair.first; |
| 185 | + BackendType backendType = pair.second; |
| 186 | + |
| 187 | + auto parentBackend = getBackend(deviceType); |
| 188 | + auto backendOpts = |
| 189 | + opts.has_value() ? opts.value() : parentBackend->getBackendOptions(); |
| 190 | + backendOpts->group_name = groupName; |
| 191 | + backendOpts->timeout = |
| 192 | + timeout.has_value() ? timeout.value() : backendOpts->timeout; |
| 193 | + auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts); |
| 194 | + if (splitBackend == nullptr) { |
| 195 | + continue; |
| 196 | + } |
| 197 | + |
| 198 | + // TODO: Figure out a better way for split group desc. |
| 199 | + // TODO: We can add a new field in Backend::Options to specify the group |
| 200 | + // desc |
| 201 | + std::string groupDesc = desc.has_value() |
| 202 | + ? desc.value() |
| 203 | + : c10::str(getGroupDesc(), ":split:", incrementSplitCount()); |
| 204 | + splitBackend->setGroupDesc(groupDesc); |
| 205 | + |
| 206 | + if (!newGroup) { |
| 207 | + newGroup = c10::make_intrusive<ProcessGroup>( |
| 208 | + store_->clone(), splitBackend->getRank(), splitBackend->getSize()); |
| 209 | + newGroup->setDefaultBackend(backendType_); |
| 210 | + newGroup->setGroupName(groupName); |
| 211 | + newGroup->setGroupDesc(groupDesc); |
| 212 | + } |
| 213 | + newGroup->setBackend(deviceType, backendType, splitBackend); |
| 214 | + } |
| 215 | + |
| 216 | + return newGroup; |
| 217 | +} |
| 218 | + |
161 | 219 | } // namespace c10d |
162 | 220 |
|
163 | 221 | namespace { |
|
0 commit comments