Skip to content

Commit 129859e

Browse files
authored
Support data type int64 in NCCL. (#9818)
1 parent 1d88ebe commit 129859e

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

paddle/fluid/platform/nccl_helper.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
#pragma once
1616

17-
#include <thread>
17+
#include <thread> // NOLINT
1818
#include <typeindex>
19+
#include <vector>
1920
#include "paddle/fluid/platform/dynload/nccl.h"
2021
#include "paddle/fluid/platform/enforce.h"
2122

@@ -29,6 +30,8 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
2930
return ncclDouble;
3031
} else if (type == typeid(int)) { // NOLINT
3132
return ncclInt;
33+
} else if (type == typeid(int64_t)) { // NOLINT
34+
return ncclInt64;
3235
} else {
3336
PADDLE_THROW("Not supported");
3437
}
@@ -66,23 +69,23 @@ struct NCCLContext {
6669
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
6770
}
6871

69-
static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
72+
static void InitNCCLContext(std::unordered_map<int, NCCLContext> *contexts,
7073
const std::vector<platform::Place> &places) {
7174
std::vector<ncclComm_t> comms;
7275
std::vector<int> devs;
73-
comms.resize(contexts.size());
74-
devs.reserve(contexts.size());
76+
comms.resize(contexts->size());
77+
devs.reserve(contexts->size());
7578

7679
for (auto &p : places) {
7780
devs.push_back(boost::get<platform::CUDAPlace>(p).device);
7881
}
7982

8083
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
81-
&comms[0], static_cast<int>(contexts.size()), &devs[0]));
84+
&comms[0], static_cast<int>(contexts->size()), &devs[0]));
8285

8386
int i = 0;
8487
for (auto &dev_id : devs) {
85-
contexts.at(dev_id).comm_ = comms[i++];
88+
contexts->at(dev_id).comm_ = comms[i++];
8689
}
8790
}
8891
};
@@ -91,7 +94,7 @@ struct NCCLContextMap {
9194
std::unordered_map<int, NCCLContext> contexts_;
9295
std::vector<int> order_;
9396

94-
NCCLContextMap(const std::vector<platform::Place> &places) {
97+
explicit NCCLContextMap(const std::vector<platform::Place> &places) {
9598
order_.reserve(places.size());
9699
for (auto &p : places) {
97100
int dev_id = boost::get<CUDAPlace>(p).device;

0 commit comments

Comments
 (0)