@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
10
10
limitations under the License. */
11
11
12
12
#include < functional>
13
+ #include < unordered_map>
13
14
14
15
#include " paddle/fluid/framework/lod_tensor.h"
15
16
#include " paddle/fluid/framework/op_registry.h"
@@ -37,37 +38,43 @@ class NCCLTypeWrapper<double> {
37
38
static const ncclDataType_t type = ncclDouble;
38
39
};
39
40
41
+ static ncclRedOp_t str_to_nccl_red_type (std::string reduction) {
42
+ static const std::unordered_map<std::string, ncclRedOp_t> str_to_type = {
43
+ {" ncclSum" , ncclSum},
44
+ {" ncclMin" , ncclMin},
45
+ {" ncclMax" , ncclMax},
46
+ {" ncclProd" , ncclProd},
47
+ };
48
+ auto it = str_to_type.find (reduction);
49
+ PADDLE_ENFORCE_EQ (it != str_to_type.end (), true ,
50
+ platform::errors::InvalidArgument (
51
+ " Invalid nccl reduction. Must be ncclMin | ncclMax | "
52
+ " ncclProd | ncclSum" ));
53
+ return it->second ;
54
+ }
55
+
40
56
template <typename T>
41
57
class NCCLAllReduceKernel : public framework ::OpKernel<T> {
42
58
public:
43
59
void Compute (const framework::ExecutionContext& ctx) const override {
44
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
45
- " This kernel only runs on GPU device." );
60
+ PADDLE_ENFORCE_EQ (platform::is_gpu_place (ctx.GetPlace ()), true ,
61
+ platform::errors::PreconditionNotMet (
62
+ " This kernel only runs on GPU device." ));
46
63
auto * x = ctx.Input <LoDTensor>(" X" );
47
64
auto * out = ctx.Output <LoDTensor>(" Out" );
48
65
auto * comm = ctx.Input <Communicator>(" Communicator" );
49
66
std::string reduction = ctx.Attr <std::string>(" reduction" );
50
67
51
- ncclRedOp_t reduction_op_ = ncclSum;
52
- if (reduction == " ncclMin" ) {
53
- reduction_op_ = ncclMin;
54
- } else if (reduction == " ncclMax" ) {
55
- reduction_op_ = ncclMax;
56
- } else if (reduction == " ncclSum" ) {
57
- reduction_op_ = ncclSum;
58
- } else if (reduction == " ncclProd" ) {
59
- reduction_op_ = ncclProd;
60
- } else {
61
- PADDLE_THROW (" Invalid reduction. default ncclSum." );
62
- }
68
+ auto reduction_op_ = str_to_nccl_red_type (reduction);
69
+
63
70
// device id
64
71
int gpu_id =
65
72
BOOST_GET_CONST (platform::CUDAPlace, ctx.GetPlace ()).GetDeviceId ();
66
73
int idx = comm->GetCommId (gpu_id);
67
74
VLOG (3 ) << " gpu : "
68
75
<< " invoke allreduce. send " << x->numel () << " recv "
69
76
<< out->numel ();
70
- PADDLE_ENFORCE (platform::dynload::ncclAllReduce (
77
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclAllReduce (
71
78
x->data <T>(), out->mutable_data <T>(ctx.GetPlace ()), out->numel (),
72
79
NCCLTypeWrapper<T>::type, reduction_op_, comm->comms ().at (idx),
73
80
ctx.cuda_device_context ().stream ()));
@@ -81,26 +88,17 @@ template <typename T>
81
88
class NCCLReduceKernel : public framework ::OpKernel<T> {
82
89
public:
83
90
void Compute (const framework::ExecutionContext& ctx) const override {
84
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
85
- " This kernel only runs on GPU device." );
91
+ PADDLE_ENFORCE_EQ (platform::is_gpu_place (ctx.GetPlace ()), true ,
92
+ platform::errors::InvalidArgument (
93
+ " This kernel only runs on GPU device." ));
86
94
auto x = ctx.Input <LoDTensor>(" X" ); // x0, x1, x2
87
95
auto out = ctx.Output <LoDTensor>(" Out" );
88
96
auto * comm = ctx.Input <Communicator>(" Communicator" );
89
97
int root = ctx.Attr <int >(" root" );
90
98
std::string reduction = ctx.Attr <std::string>(" reduction" );
91
99
92
- ncclRedOp_t reduction_op_ = ncclSum;
93
- if (reduction == " ncclMin" ) {
94
- reduction_op_ = ncclMin;
95
- } else if (reduction == " ncclMax" ) {
96
- reduction_op_ = ncclMax;
97
- } else if (reduction == " ncclSum" ) {
98
- reduction_op_ = ncclSum;
99
- } else if (reduction == " ncclProd" ) {
100
- reduction_op_ = ncclProd;
101
- } else {
102
- PADDLE_THROW (" Invalid reduction. default ncclSum." );
103
- }
100
+ auto reduction_op_ = str_to_nccl_red_type (reduction);
101
+
104
102
// device id
105
103
int gpu_id =
106
104
BOOST_GET_CONST (platform::CUDAPlace, ctx.GetPlace ()).GetDeviceId ();
@@ -113,7 +111,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
113
111
}
114
112
VLOG (3 ) << " gpu : " << gpu_id << " invoke reduce. send " << x->numel ()
115
113
<< " recv " << out->numel ();
116
- PADDLE_ENFORCE (platform::dynload::ncclReduce (
114
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclReduce (
117
115
x->data <T>(), recvbuffer, x->numel (), NCCLTypeWrapper<T>::type,
118
116
reduction_op_, root, comm->comms ().at (idx),
119
117
ctx.cuda_device_context ().stream ()));
@@ -126,8 +124,9 @@ template <typename T>
126
124
class NCCLBcastKernel : public framework ::OpKernel<T> {
127
125
public:
128
126
void Compute (const framework::ExecutionContext& ctx) const override {
129
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
130
- " This kernel only runs on GPU device." );
127
+ PADDLE_ENFORCE_EQ (platform::is_gpu_place (ctx.GetPlace ()), true ,
128
+ platform::errors::InvalidArgument (
129
+ " This kernel only runs on GPU device." ));
131
130
int root = ctx.Attr <int >(" root" );
132
131
auto * comm = ctx.Input <Communicator>(" Communicator" );
133
132
// device id
@@ -137,7 +136,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
137
136
if (idx == root) {
138
137
auto * x = ctx.Input <LoDTensor>(" X" );
139
138
VLOG (3 ) << " gpu : " << gpu_id << " invoke Bcast. send " << x->numel ();
140
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
139
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclBcast (
141
140
reinterpret_cast <void *>(const_cast <T*>(x->data <T>())), x->numel (),
142
141
NCCLTypeWrapper<T>::type, root, comm->comms ().at (idx),
143
142
ctx.cuda_device_context ().stream ()));
@@ -146,7 +145,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
146
145
auto * out = ctx.Output <LoDTensor>(" Out" );
147
146
VLOG (3 ) << " gpu : " << gpu_id << " invoke Bcast. recv buffer "
148
147
<< framework::product (out->dims ());
149
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
148
+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::ncclBcast (
150
149
out->mutable_data <T>(ctx.GetPlace ()), out->numel (),
151
150
NCCLTypeWrapper<T>::type, root, comm->comms ().at (idx),
152
151
ctx.cuda_device_context ().stream ()));
0 commit comments