@@ -33,14 +33,26 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);

-    // TODO(zcd): Sometimes direct copies will be faster
-    std::vector<framework::Tensor> inputs(ins.size());
-    for (size_t j = 0; j < ins.size(); ++j) {
-      inputs[j] = *ins[j];
+    // Sometimes direct copies will be faster; this may need deeper analysis.
+    if (axis == 0 && ins.size() < 10) {
+      size_t output_offset = 0;
+      for (auto* in : ins) {
+        auto in_stride = framework::stride_numel(in->dims());
+        auto out_stride = framework::stride_numel(out->dims());
+        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
+                                    out->data<T>() + output_offset, out_stride,
+                                    in->data<T>(), in_stride, in_stride[axis]);
+        output_offset += in_stride[axis];
+      }
+    } else {
+      std::vector<framework::Tensor> inputs(ins.size());
+      for (size_t j = 0; j < ins.size(); ++j) {
+        inputs[j] = *ins[j];
+      }
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
+      concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
     }
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
-    concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
   }
 };

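A note on what the new axis == 0 branch buys: with row-major tensors, framework::stride_numel(dims)[0] is the tensor's total element count, so concatenation along axis 0 degenerates into back-to-back contiguous copies. Below is a minimal standalone sketch of that idea, assuming a hypothetical flat Tensor struct and concat_axis0 helper; it is not Paddle's StridedNumelCopyWithAxis.

// Hypothetical sketch, not Paddle code: shows why axis-0 concat is a
// sequence of plain memcpy calls over whole input buffers.
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

struct Tensor {  // hypothetical flat, row-major tensor
  std::vector<int64_t> dims;
  std::vector<float> data;
  int64_t numel() const {
    return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
};

// Concatenating along axis 0 appends each input's whole buffer to the
// output; output_offset advances by in.numel(), which is what
// in_stride[axis] equals in the kernel above when axis == 0.
Tensor concat_axis0(const std::vector<Tensor>& ins) {
  Tensor out;
  out.dims = ins.front().dims;
  out.dims[0] = 0;
  for (const auto& in : ins) out.dims[0] += in.dims[0];
  out.data.resize(static_cast<size_t>(out.numel()));

  size_t output_offset = 0;
  for (const auto& in : ins) {
    std::memcpy(out.data.data() + output_offset, in.data.data(),
                static_cast<size_t>(in.numel()) * sizeof(float));
    output_offset += static_cast<size_t>(in.numel());
  }
  return out;
}

int main() {
  Tensor a{{2, 3}, {0, 1, 2, 3, 4, 5}};
  Tensor b{{1, 3}, {6, 7, 8}};
  Tensor c = concat_axis0({a, b});               // c.dims == {3, 3}
  for (float v : c.data) std::cout << v << ' ';  // 0 1 2 3 4 5 6 7 8
  std::cout << '\n';
  return 0;
}
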
@@ -52,17 +64,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));

-    // TODO(zcd): Sometimes direct copies will be faster
-    std::vector<framework::Tensor> outputs(outs.size());
-    for (size_t j = 0; j < outs.size(); ++j) {
-      outs[j]->mutable_data<T>(ctx.GetPlace());
-      outputs[j] = *outs[j];
-    }
+    // Sometimes direct copies will be faster; this may need deeper analysis.
+    if (axis == 0 && outs.size() < 10) {
+      size_t input_offset = 0;
+      auto in_stride = framework::stride_numel(in->dims());
+
+      for (auto& out : outs) {
+        out->mutable_data<T>(ctx.GetPlace());
+        auto out_stride = framework::stride_numel(out->dims());
+        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
+                                    out_stride, in->data<T>() + input_offset,
+                                    in_stride, out_stride[axis]);
+        input_offset += out_stride[axis];
+      }
+    } else {
+      std::vector<framework::Tensor> outputs(outs.size());
+      for (size_t j = 0; j < outs.size(); ++j) {
+        outs[j]->mutable_data<T>(ctx.GetPlace());
+        outputs[j] = *outs[j];
+      }

-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
-        concat_grad_functor;
-    concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
+          concat_grad_functor;
+      concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
+    }
   }
 };

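The gradient kernel's fast path is the mirror image: splitting the upstream gradient along axis 0 scatters contiguous chunks into each output, advancing input_offset by out_stride[axis]. A companion sketch, reusing the hypothetical Tensor above; split_axis0 is an illustrative name, not a Paddle API.

// Hypothetical sketch: each output receives the next contiguous chunk of
// the input gradient, and input_offset advances by that output's element
// count (out_stride[0] when axis == 0). Each out.dims is assumed to be
// set already, as the kernel's outs are.
void split_axis0(const Tensor& in, std::vector<Tensor>* outs) {
  size_t input_offset = 0;
  for (auto& out : *outs) {
    out.data.resize(static_cast<size_t>(out.numel()));
    std::memcpy(out.data.data(), in.data.data() + input_offset,
                static_cast<size_t>(out.numel()) * sizeof(float));
    input_offset += static_cast<size_t>(out.numel());
  }
}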