@@ -84,7 +84,7 @@ void GatherOpHandle::RunImpl() {
84
84
" The type of input is not consistent." );
85
85
PADDLE_ENFORCE_EQ (pre_in.height (), in_sr.height (),
86
86
" The height of inputs is not consistent." );
87
- PADDLE_ENFORCE_EQ (pre_in.GetCompleteDims (), in_sr.GetCompleteDims (), ,
87
+ PADDLE_ENFORCE_EQ (pre_in.GetCompleteDims (), in_sr.GetCompleteDims (),
88
88
" The dims of inputs is not consistent." );
89
89
90
90
auto in_sr_rows = in_sr.rows ();
@@ -110,14 +110,17 @@ void GatherOpHandle::RunImpl() {
110
110
Tensor *out_tensor = out->mutable_value ();
111
111
112
112
// copy
113
- int s = 0 , e = 0 ;
114
- for (size_t j = 0 ; j < in_tensors.size (); ++j) {
115
- e += in_tensors[j].dims ()[0 ];
116
- auto sub_out = out_tensor->Slice (s, e);
117
- paddle::framework::TensorCopy (in_tensors[j], out_place,
118
- *(dev_ctxes_[in_places[j]]), &sub_out);
119
- s = e;
120
- }
113
+ auto dev_ctx = dev_ctxes_[out_place];
114
+ RunAndRecordEvent (out_place, [in_tensors, out_var, dev_ctx, out_place] {
115
+ int s = 0 , e = 0 ;
116
+ for (size_t j = 0 ; j < in_tensors.size (); ++j) {
117
+ e += in_tensors[j].dims ()[0 ];
118
+ auto sub_out = out_tensor->Slice (s, e);
119
+ paddle::framework::TensorCopy (in_tensors[j], out_place, *(dev_ctx),
120
+ &sub_out);
121
+ s = e;
122
+ }
123
+ });
121
124
}
122
125
123
126
// Reports the fixed name string of this op handle.
std::string GatherOpHandle::Name() const {
  std::string handle_name{" gather"};
  return handle_name;
}
0 commit comments