@@ -60,34 +60,45 @@ template <typename DeviceContext, typename T>
 class ConcatGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
 
+    // Collect output tensors; store nullptr where the name is kEmptyVarName.
+    std::vector<framework::Tensor*> outputs;
+    for (size_t j = 0; j < outs.size(); ++j) {
+      if (out_var_names[j] != framework::kEmptyVarName) {
+        outs[j]->mutable_data<T>(ctx.GetPlace());
+        outputs.push_back(outs[j]);
+      } else {
+        outputs.push_back(nullptr);
+      }
+    }
+
     // Sometimes direct copies will be faster; this may need deeper analysis.
     if (axis == 0 && outs.size() < 10) {
       size_t input_offset = 0;
-      auto in_stride = framework::stride_numel(in->dims());
+      const auto in_stride = framework::stride_numel(out_grad->dims());
 
-      for (auto& out : outs) {
-        out->mutable_data<T>(ctx.GetPlace());
-        auto out_stride = framework::stride_numel(out->dims());
-        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                    out_stride, in->data<T>() + input_offset,
-                                    in_stride, out_stride[axis]);
+      for (size_t i = 0; i < outs.size(); ++i) {
+        auto out_stride = framework::stride_numel(ins[i]->dims());
+        auto* out = outputs[i];
+        if (out != nullptr) {
+          StridedNumelCopyWithAxis<T>(
+              ctx.device_context(), axis, out->data<T>(), out_stride,
+              out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
+        }
         input_offset += out_stride[axis];
       }
     } else {
-      std::vector<framework::Tensor> outputs(outs.size());
-      for (size_t j = 0; j < outs.size(); ++j) {
-        outs[j]->mutable_data<T>(ctx.GetPlace());
-        outputs[j] = *outs[j];
-      }
-
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
       paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
           concat_grad_functor;
-      concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), &outputs);
+      concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
+                          &outputs);
     }
   }
 };
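
For reference, the fast path above copies each input's gradient slice directly out of out_grad when axis == 0, skipping outputs whose gradient variable is kEmptyVarName while still advancing the read offset so later slices stay aligned. A minimal standalone sketch of that skip-with-offset pattern (plain C++; Slice and SplitAxis0 are hypothetical names for illustration, not the Paddle API):

#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a gradient output; not a Paddle type.
struct Slice {
  float* data;        // destination buffer, or nullptr when no gradient is needed
  std::size_t numel;  // elements this input contributed along axis 0
};

// Split the concatenated gradient `src` back into per-input slices.
// Skipped (nullptr) slices still advance the read offset, mirroring the
// kernel's `input_offset += out_stride[axis];` outside the null check.
void SplitAxis0(const float* src, std::vector<Slice>* slices) {
  std::size_t offset = 0;
  for (auto& s : *slices) {
    if (s.data != nullptr) {
      std::memcpy(s.data, src + offset, s.numel * sizeof(float));
    }
    offset += s.numel;  // advance even when this slice is skipped
  }
}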