// limitations under the License.

#include "paddle/fluid/framework/details/reduce_op_handle.h"
- #include "paddle/fluid/framework/details/gather_op_handle.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
- #include "paddle/fluid/platform/nccl_helper.h"

namespace paddle {
namespace framework {
namespace details {

- std::vector<VarHandle *> GetValidVarHandle(
-     const std::vector<VarHandleBase *> &inputs) {
-   std::vector<VarHandle *> in_var_handles;
-   for (auto *in : inputs) {
-     auto *in_handle = dynamic_cast<VarHandle *>(in);
-     if (in_handle) {
-       in_var_handles.push_back(in_handle);
-     }
-   }
-   return in_var_handles;
- }
-
void ReduceOpHandle::RunImpl() {
  // the input and output may have dummy var.
-   std::vector<VarHandle *> in_var_handles = GetValidVarHandle(inputs_);
-   std::vector<VarHandle *> out_var_handles = GetValidVarHandle(outputs_);
+   std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
+   std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);

  PADDLE_ENFORCE_EQ(
      in_var_handles.size(), places_.size(),
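
For context on the GetValidVarHandle -> GetValidVarHandles rename above: the helper filters the real variable handles out of a mixed input list that may also contain dummy dependency-only handles. Below is a minimal, self-contained sketch of the same dynamic_cast filtering pattern; the types here are hypothetical stand-ins for Paddle's VarHandleBase/VarHandle hierarchy and its dummy handles, not the actual definitions.

    #include <iostream>
    #include <vector>

    // Hypothetical stand-ins; only the inheritance shape matters here.
    struct VarHandleBase { virtual ~VarHandleBase() = default; };
    struct VarHandle : VarHandleBase {
      explicit VarHandle(int i) : id(i) {}
      int id;
    };
    struct DummyVarHandle : VarHandleBase {};

    // Same idea as GetValidVarHandles: keep only real VarHandles and
    // drop the dummy handles that exist solely to model dependencies.
    std::vector<VarHandle *> FilterValid(
        const std::vector<VarHandleBase *> &ins) {
      std::vector<VarHandle *> out;
      for (auto *in : ins) {
        if (auto *h = dynamic_cast<VarHandle *>(in)) out.push_back(h);
      }
      return out;
    }

    int main() {
      VarHandle a(0), b(1);
      DummyVarHandle dummy;
      std::vector<VarHandleBase *> ins{&a, &dummy, &b};
      for (auto *h : FilterValid(ins)) std::cout << h->id << "\n";  // 0, 1
    }
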
@@ -45,15 +31,10 @@ void ReduceOpHandle::RunImpl() {
      "The number of output should be one.");

  // Wait input done, this Wait is asynchronous operation
-   if (in_var_handles[0]->generated_op_) {
-     for (auto *in : in_var_handles) {
-       auto &in_p = in->place_;
-       in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in_p]);
-     }
-   }
+   WaitEvents(in_var_handles);

  // check in the same place
-   auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
+   auto in_0_handle = in_var_handles[0];
  auto pre_place = in_0_handle->place_;

  std::vector<platform::Place> in_places;
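
The inline wait loop above is extracted into a WaitEvents helper (defined in the final hunk). Assuming the usual event-based implementation behind generated_op_->Wait(dev_ctx), the wait expresses a device-side dependency between streams rather than blocking the host. A minimal CUDA sketch of that cross-stream pattern follows; the stream parameters and function name are illustrative, while the CUDA calls are the real runtime API.

    #include <cuda_runtime.h>

    // Make `consumer` wait for all work currently queued on `producer`,
    // without stalling the CPU.
    void WaitOnStream(cudaStream_t producer, cudaStream_t consumer) {
      cudaEvent_t done;
      cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
      cudaEventRecord(done, producer);         // mark producer's current tail
      cudaStreamWaitEvent(consumer, done, 0);  // consumer waits on the GPU side
      cudaEventDestroy(done);                  // safe: freed once it completes
    }
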
@@ -120,6 +101,7 @@ void ReduceOpHandle::RunImpl() {
    for (size_t i = 0; i < local_scopes_.size(); ++i) {
      auto &p = in_places[i];
      auto &lod_tensor = lod_tensors[i];
+
      int dev_id = boost::get<platform::CUDAPlace>(p).device;
      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
      auto stream = nccl_ctx.stream();
@@ -139,18 +121,41 @@ void ReduceOpHandle::RunImpl() {
      });
    }

-   platform::NCCLGroupGuard guard;
-   for (auto &call : all_reduce_calls) {
-     call();
-   }
+   this->RunAndRecordEvent([&] {
+     platform::NCCLGroupGuard guard;
+     for (auto &call : all_reduce_calls) {
+       call();
+     }
+   });
#else
    PADDLE_THROW("CUDA is not support.");
#endif
    } else {
-     PADDLE_THROW("Error");
+     PADDLE_THROW("Place should be CPUPlace or CUDAPlace.");
    }
  }
}

+ void ReduceOpHandle::WaitEvents(
+     const std::vector<VarHandle *> &in_var_handles) {
+   if (in_var_handles[0]->generated_op_) {
+     for (auto *in : in_var_handles) {
+       in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in->place_]);
+     }
+   }
+ }
+
+ std::vector<VarHandle *> ReduceOpHandle::GetValidVarHandles(
+     const std::vector<VarHandleBase *> &inputs) {
+   std::vector<VarHandle *> in_var_handles;
+   for (auto *in : inputs) {
+     auto *in_handle = dynamic_cast<VarHandle *>(in);
+     if (in_handle) {
+       in_var_handles.push_back(in_handle);
+     }
+   }
+   return in_var_handles;
+ }

std::string ReduceOpHandle::Name() const { return "reduce"; }
}  // namespace details
}  // namespace framework
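
A note on the NCCLGroupGuard change in the final hunk: the per-device reduce calls are collected first and then issued inside one NCCL group, now wrapped in RunAndRecordEvent so that completion is recorded as an event downstream ops can wait on. Below is a minimal sketch of the collect-then-group pattern; ncclGroupStart/ncclGroupEnd are the actual NCCL API, while the guard class and RunGrouped helper are illustrative stand-ins for platform::NCCLGroupGuard and the lambda in the diff.

    #include <functional>
    #include <vector>

    #include <nccl.h>

    // RAII guard in the spirit of platform::NCCLGroupGuard: every
    // collective issued between construction and destruction joins
    // one NCCL group.
    struct NcclGroupGuard {
      NcclGroupGuard() { ncclGroupStart(); }
      ~NcclGroupGuard() { ncclGroupEnd(); }
    };

    // Collect the per-device calls first, then fire them inside one
    // group so NCCL can launch them together instead of serializing
    // (or deadlocking) across devices.
    void RunGrouped(const std::vector<std::function<void()>> &calls) {
      NcclGroupGuard guard;
      for (auto &call : calls) {
        call();  // e.g. one ncclReduce per device, as in the diff above
      }
    }  // guard's destructor closes the group here
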