@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
10
10
limitations under the License. */
11
11
12
12
#include < functional>
13
+ #include < unordered_map>
13
14
14
15
#include " paddle/fluid/framework/lod_tensor.h"
15
16
#include " paddle/fluid/framework/op_registry.h"
@@ -37,36 +38,42 @@ class NCCLTypeWrapper<double> {
37
38
static const ncclDataType_t type = ncclDouble;
38
39
};
39
40
41
+ static ncclRedOp_t str_to_nccl_red_type (std::string reduction) {
42
+ static const std::unordered_map<std::string, ncclRedOp_t> str_to_type = {
43
+ {" ncclSum" , ncclSum},
44
+ {" ncclMin" , ncclMin},
45
+ {" ncclMax" , ncclMax},
46
+ {" ncclProd" , ncclProd},
47
+ };
48
+ auto it = str_to_type.find (reduction);
49
+ PADDLE_ENFORCE_EQ (it != str_to_type.end (), true ,
50
+ platform::errors::InvalidArgument (
51
+ " Invalid nccl reduction. Must be ncclMin | ncclMax | "
52
+ " ncclProd | ncclSum" ));
53
+ return it->second ;
54
+ }
55
+
40
56
template <typename T>
41
57
class NCCLAllReduceKernel : public framework ::OpKernel<T> {
42
58
public:
43
59
void Compute (const framework::ExecutionContext& ctx) const override {
44
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
45
- " This kernel only runs on GPU device." );
60
+ PADDLE_ENFORCE_EQ (platform::is_gpu_place (ctx.GetPlace ()), true ,
61
+ platform::errors::PreconditionNotMet (
62
+ " This kernel only runs on GPU device." ));
46
63
auto * x = ctx.Input <LoDTensor>(" X" );
47
64
auto * out = ctx.Output <LoDTensor>(" Out" );
48
65
auto * comm = ctx.Input <Communicator>(" Communicator" );
49
66
std::string reduction = ctx.Attr <std::string>(" reduction" );
50
67
51
- ncclRedOp_t reduction_op_ = ncclSum;
52
- if (reduction == " ncclMin" ) {
53
- reduction_op_ = ncclMin;
54
- } else if (reduction == " ncclMax" ) {
55
- reduction_op_ = ncclMax;
56
- } else if (reduction == " ncclSum" ) {
57
- reduction_op_ = ncclSum;
58
- } else if (reduction == " ncclProd" ) {
59
- reduction_op_ = ncclProd;
60
- } else {
61
- PADDLE_THROW (" Invalid reduction. default ncclSum." );
62
- }
68
+ auto reduction_op_ = str_to_nccl_red_type (reduction);
69
+
63
70
// device id
64
71
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace ()).GetDeviceId ();
65
72
int idx = comm->GetCommId (gpu_id);
66
73
VLOG (3 ) << " gpu : "
67
74
<< " invoke allreduce. send " << x->numel () << " recv "
68
75
<< out->numel ();
69
- PADDLE_ENFORCE (platform::dynload::ncclAllReduce (
76
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclAllReduce (
70
77
x->data <T>(), out->mutable_data <T>(ctx.GetPlace ()), out->numel (),
71
78
NCCLTypeWrapper<T>::type, reduction_op_, comm->comms ().at (idx),
72
79
ctx.cuda_device_context ().stream ()));
@@ -80,26 +87,17 @@ template <typename T>
80
87
class NCCLReduceKernel : public framework ::OpKernel<T> {
81
88
public:
82
89
void Compute (const framework::ExecutionContext& ctx) const override {
83
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
84
- " This kernel only runs on GPU device." );
90
+ PADDLE_ENFORCE_EQ (platform::is_gpu_place (ctx.GetPlace ()), true ,
91
+ platform::errors::InvalidArgument (
92
+ " This kernel only runs on GPU device." ));
85
93
auto x = ctx.Input <LoDTensor>(" X" ); // x0, x1, x2
86
94
auto out = ctx.Output <LoDTensor>(" Out" );
87
95
auto * comm = ctx.Input <Communicator>(" Communicator" );
88
96
int root = ctx.Attr <int >(" root" );
89
97
std::string reduction = ctx.Attr <std::string>(" reduction" );
90
98
91
- ncclRedOp_t reduction_op_ = ncclSum;
92
- if (reduction == " ncclMin" ) {
93
- reduction_op_ = ncclMin;
94
- } else if (reduction == " ncclMax" ) {
95
- reduction_op_ = ncclMax;
96
- } else if (reduction == " ncclSum" ) {
97
- reduction_op_ = ncclSum;
98
- } else if (reduction == " ncclProd" ) {
99
- reduction_op_ = ncclProd;
100
- } else {
101
- PADDLE_THROW (" Invalid reduction. default ncclSum." );
102
- }
99
+ auto reduction_op_ = str_to_nccl_red_type (reduction);
100
+
103
101
// device id
104
102
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace ()).GetDeviceId ();
105
103
int idx = comm->GetCommId (gpu_id);
@@ -111,7 +109,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
111
109
}
112
110
VLOG (3 ) << " gpu : " << gpu_id << " invoke reduce. send " << x->numel ()
113
111
<< " recv " << out->numel ();
114
- PADDLE_ENFORCE (platform::dynload::ncclReduce (
112
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclReduce (
115
113
x->data <T>(), recvbuffer, x->numel (), NCCLTypeWrapper<T>::type,
116
114
reduction_op_, root, comm->comms ().at (idx),
117
115
ctx.cuda_device_context ().stream ()));
@@ -124,8 +122,9 @@ template <typename T>
124
122
class NCCLBcastKernel : public framework ::OpKernel<T> {
125
123
public:
126
124
void Compute (const framework::ExecutionContext& ctx) const override {
127
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
128
- " This kernel only runs on GPU device." );
125
+ PADDLE_ENFORCE_EQ (platform::is_gpu_place (ctx.GetPlace ()), true ,
126
+ platform::errors::InvalidArgument (
127
+ " This kernel only runs on GPU device." ));
129
128
int root = ctx.Attr <int >(" root" );
130
129
auto * comm = ctx.Input <Communicator>(" Communicator" );
131
130
// device id
@@ -134,7 +133,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
134
133
if (idx == root) {
135
134
auto * x = ctx.Input <LoDTensor>(" X" );
136
135
VLOG (3 ) << " gpu : " << gpu_id << " invoke Bcast. send " << x->numel ();
137
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
136
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclBcast (
138
137
reinterpret_cast <void *>(const_cast <T*>(x->data <T>())), x->numel (),
139
138
NCCLTypeWrapper<T>::type, root, comm->comms ().at (idx),
140
139
ctx.cuda_device_context ().stream ()));
@@ -143,7 +142,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
143
142
auto * out = ctx.Output <LoDTensor>(" Out" );
144
143
VLOG (3 ) << " gpu : " << gpu_id << " invoke Bcast. recv buffer "
145
144
<< framework::product (out->dims ());
146
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
145
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclBcast (
147
146
out->mutable_data <T>(ctx.GetPlace ()), out->numel (),
148
147
NCCLTypeWrapper<T>::type, root, comm->comms ().at (idx),
149
148
ctx.cuda_device_context ().stream ()));
0 commit comments