Skip to content

Commit c3864ea

Browse files
committed
if axis == 0, directly copy D->D
1 parent 131ec27 commit c3864ea

File tree

1 file changed

+43
-17
lines changed

1 file changed

+43
-17
lines changed

paddle/fluid/operators/concat_op.h

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,26 @@ class ConcatKernel : public framework::OpKernel<T> {
3333
auto place = ctx.GetPlace();
3434
out->mutable_data<T>(place);
3535

36-
// TODO(zcd): Sometimes direct copies will be faster
37-
std::vector<framework::Tensor> inputs(ins.size());
38-
for (size_t j = 0; j < ins.size(); ++j) {
39-
inputs[j] = *ins[j];
36+
// Sometimes direct copies will be faster; this may need deeper analysis.
37+
if (axis == 0 && ins.size() < 10) {
38+
size_t output_offset = 0;
39+
for (auto* in : ins) {
40+
auto in_stride = framework::stride_numel(in->dims());
41+
auto out_stride = framework::stride_numel(out->dims());
42+
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
43+
out->data<T>() + output_offset, out_stride,
44+
in->data<T>(), in_stride, in_stride[axis]);
45+
output_offset += in_stride[axis];
46+
}
47+
} else {
48+
std::vector<framework::Tensor> inputs(ins.size());
49+
for (size_t j = 0; j < ins.size(); ++j) {
50+
inputs[j] = *ins[j];
51+
}
52+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
53+
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
54+
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
4055
}
41-
auto& dev_ctx = ctx.template device_context<DeviceContext>();
42-
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
43-
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
4456
}
4557
};
4658

@@ -52,17 +64,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
5264
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
5365
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
5466

55-
// TODO(zcd): Sometimes direct copies will be faster
56-
std::vector<framework::Tensor> outputs(outs.size());
57-
for (size_t j = 0; j < outs.size(); ++j) {
58-
outs[j]->mutable_data<T>(ctx.GetPlace());
59-
outputs[j] = *outs[j];
60-
}
67+
// Sometimes direct copies will be faster; this may need deeper analysis.
68+
if (axis == 0 && outs.size() < 10) {
69+
size_t input_offset = 0;
70+
auto in_stride = framework::stride_numel(in->dims());
71+
72+
for (auto& out : outs) {
73+
out->mutable_data<T>(ctx.GetPlace());
74+
auto out_stride = framework::stride_numel(out->dims());
75+
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
76+
out_stride, in->data<T>() + input_offset,
77+
in_stride, out_stride[axis]);
78+
input_offset += out_stride[axis];
79+
}
80+
} else {
81+
std::vector<framework::Tensor> outputs(outs.size());
82+
for (size_t j = 0; j < outs.size(); ++j) {
83+
outs[j]->mutable_data<T>(ctx.GetPlace());
84+
outputs[j] = *outs[j];
85+
}
6186

62-
auto& dev_ctx = ctx.template device_context<DeviceContext>();
63-
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
64-
concat_grad_functor;
65-
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
87+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
88+
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
89+
concat_grad_functor;
90+
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
91+
}
6692
}
6793
};
6894

0 commit comments

Comments
 (0)