@@ -17,6 +17,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
 
 namespace paddle {
@@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::Tensor>("Out");
+    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
 
-    auto out_stride = framework::stride_numel(out->dims());
-
-    size_t output_offset = 0;
-
-    // If axis >=1, copy to out immediately need to call many times
-    // of cuda memcpy. Copy the input to cpu and do the stride copy,
-    // then copy to gpu output.
-
-    if (platform::is_gpu_place(place) && axis >= 1) {
-      platform::CPUPlace copy_place;
-      auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
-      framework::Tensor cpu_out;
-      cpu_out.Resize(out->dims());
-      cpu_out.mutable_data<T>(copy_place);
-      auto& dev_ctx = ctx.device_context();
-      std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
-      for (auto* in : ins) {
-        std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
-        framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
-        cpu_ins.emplace_back(std::move(cpu_in));
-      }
-      // TODO(dzhwinter): overlap copy and compute stream
-      // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
-      dev_ctx.Wait();
-
-      for (auto& in : cpu_ins) {
-        auto& cpu_in = *in.get();
-        auto in_stride = framework::stride_numel(cpu_in.dims());
-
-        StridedNumelCopyWithAxis<T>(
-            cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
-            cpu_in.data<T>(), in_stride, in_stride[axis]);
-        output_offset += in_stride[axis];
-      }
-      framework::TensorCopy(cpu_out, place, dev_ctx, out);
-    } else {
+    // Sometimes direct copies will be faster; this may need deeper analysis.
+    if (axis == 0 && ins.size() < 10) {
+      size_t output_offset = 0;
       for (auto* in : ins) {
         auto in_stride = framework::stride_numel(in->dims());
+        auto out_stride = framework::stride_numel(out->dims());
         StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                     out->data<T>() + output_offset, out_stride,
                                     in->data<T>(), in_stride, in_stride[axis]);
         output_offset += in_stride[axis];
       }
+    } else {
+      std::vector<framework::Tensor> inputs(ins.size());
+      for (size_t j = 0; j < ins.size(); ++j) {
+        inputs[j] = *ins[j];
+      }
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
+      concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
     }
   }
 };
@@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    size_t input_offset = 0;
-    auto in_stride = framework::stride_numel(in->dims());
 
-    for (auto& out : outs) {
-      out->mutable_data<T>(ctx.GetPlace());
-      auto out_stride = framework::stride_numel(out->dims());
-      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                  out_stride, in->data<T>() + input_offset,
-                                  in_stride, out_stride[axis]);
-      input_offset += out_stride[axis];
+    // Sometimes direct copies will be faster; this may need deeper analysis.
+    if (axis == 0 && outs.size() < 10) {
+      size_t input_offset = 0;
+      auto in_stride = framework::stride_numel(in->dims());
+
+      for (auto& out : outs) {
+        out->mutable_data<T>(ctx.GetPlace());
+        auto out_stride = framework::stride_numel(out->dims());
+        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
+                                    out_stride, in->data<T>() + input_offset,
+                                    in_stride, out_stride[axis]);
+        input_offset += out_stride[axis];
+      }
+    } else {
+      std::vector<framework::Tensor> outputs(outs.size());
+      for (size_t j = 0; j < outs.size(); ++j) {
+        outs[j]->mutable_data<T>(ctx.GetPlace());
+        outputs[j] = *outs[j];
+      }
+
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
+          concat_grad_functor;
+      concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
     }
   }
 };
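Note on the axis == 0 && size() < 10 heuristic used in both kernels: concatenating row-major tensors along axis 0 is one contiguous copy per input, while concatenation along a later axis needs one copy per outer slice, which is what makes delegating that case to a single ConcatFunctor launch attractive. Below is a minimal standalone sketch of that difference using plain std::memcpy on two hypothetical 2 x 2 float buffers; it is an illustration only, not Paddle's tensors or the functor in this patch.

#include <cstring>
#include <iostream>
#include <vector>

// Concatenate two row-major [rows x cols] matrices along axis 0:
// each input occupies one contiguous block, so one memcpy per input suffices.
std::vector<float> ConcatAxis0(const std::vector<float>& a,
                               const std::vector<float>& b, int rows_a,
                               int rows_b, int cols) {
  std::vector<float> out((rows_a + rows_b) * cols);
  std::memcpy(out.data(), a.data(), rows_a * cols * sizeof(float));
  std::memcpy(out.data() + rows_a * cols, b.data(),
              rows_b * cols * sizeof(float));
  return out;
}

// Concatenate along axis 1: every output row interleaves a row of `a` with a
// row of `b`, so each input needs one (strided) copy per row.
std::vector<float> ConcatAxis1(const std::vector<float>& a,
                               const std::vector<float>& b, int rows,
                               int cols_a, int cols_b) {
  std::vector<float> out(rows * (cols_a + cols_b));
  for (int r = 0; r < rows; ++r) {
    std::memcpy(out.data() + r * (cols_a + cols_b), a.data() + r * cols_a,
                cols_a * sizeof(float));
    std::memcpy(out.data() + r * (cols_a + cols_b) + cols_a,
                b.data() + r * cols_b, cols_b * sizeof(float));
  }
  return out;
}

int main() {
  std::vector<float> a = {1, 2, 3, 4};  // 2 x 2, rows {1,2} and {3,4}
  std::vector<float> b = {5, 6, 7, 8};  // 2 x 2, rows {5,6} and {7,8}
  for (float v : ConcatAxis0(a, b, 2, 2, 2)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 2 3 4 5 6 7 8
  for (float v : ConcatAxis1(a, b, 2, 2, 2)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 2 5 6 3 4 7 8
  return 0;
}

On a GPU each of those per-row copies becomes its own transfer or strided kernel, which is consistent with the removed comment about axis >= 1 requiring many cuda memcpy calls; the patch therefore keeps the strided path only for axis == 0 with few inputs and otherwise hands the whole batch to one functor call.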