Skip to content

Commit 69827f3

Browse files
authored
Merge pull request #11527 from jacquesqiao/concat-grad-support-data-input
concat support data as input
2 parents 210790d + ad1ad73 commit 69827f3

File tree

4 files changed

+68
-42
lines changed

4 files changed

+68
-42
lines changed

paddle/fluid/operators/concat_op.h

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,34 +60,45 @@ template <typename DeviceContext, typename T>
6060
class ConcatGradKernel : public framework::OpKernel<T> {
6161
public:
6262
void Compute(const framework::ExecutionContext& ctx) const {
63-
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
63+
auto* out_grad =
64+
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
65+
auto ins = ctx.MultiInput<framework::Tensor>("X");
66+
auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
6467
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
6568
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
6669

70+
// get output tensor that the name is not kEmptyVarName
71+
std::vector<framework::Tensor*> outputs;
72+
for (size_t j = 0; j < outs.size(); ++j) {
73+
if (out_var_names[j] != framework::kEmptyVarName) {
74+
outs[j]->mutable_data<T>(ctx.GetPlace());
75+
outputs.push_back(outs[j]);
76+
} else {
77+
outputs.push_back(nullptr);
78+
}
79+
}
80+
6781
// Sometimes direct copies will be faster, this maybe need deeply analysis.
6882
if (axis == 0 && outs.size() < 10) {
6983
size_t input_offset = 0;
70-
auto in_stride = framework::stride_numel(in->dims());
84+
const auto in_stride = framework::stride_numel(out_grad->dims());
7185

72-
for (auto& out : outs) {
73-
out->mutable_data<T>(ctx.GetPlace());
74-
auto out_stride = framework::stride_numel(out->dims());
75-
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
76-
out_stride, in->data<T>() + input_offset,
77-
in_stride, out_stride[axis]);
86+
for (size_t i = 0; i < outs.size(); ++i) {
87+
auto out_stride = framework::stride_numel(ins[i]->dims());
88+
auto* out = outputs[i];
89+
if (out != nullptr) {
90+
StridedNumelCopyWithAxis<T>(
91+
ctx.device_context(), axis, out->data<T>(), out_stride,
92+
out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
93+
}
7894
input_offset += out_stride[axis];
7995
}
8096
} else {
81-
std::vector<framework::Tensor> outputs(outs.size());
82-
for (size_t j = 0; j < outs.size(); ++j) {
83-
outs[j]->mutable_data<T>(ctx.GetPlace());
84-
outputs[j] = *outs[j];
85-
}
86-
8797
auto& dev_ctx = ctx.template device_context<DeviceContext>();
8898
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
8999
concat_grad_functor;
90-
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), &outputs);
100+
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
101+
&outputs);
91102
}
92103
}
93104
};

paddle/fluid/operators/math/concat.cc

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,35 +70,40 @@ template <typename T>
7070
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
7171
public:
7272
void operator()(const platform::CPUDeviceContext& context,
73-
const framework::Tensor& input, const int axis,
74-
std::vector<framework::Tensor>* outputs) {
73+
const framework::Tensor& input,
74+
const std::vector<const framework::Tensor*>& ref_inputs,
75+
const int axis, std::vector<framework::Tensor*>* outputs) {
7576
// TODO(zcd): Add input data validity checking
76-
int num = outputs->size();
77+
size_t num = outputs->size();
7778

7879
int input_rows = 1;
79-
auto dim_0 = outputs->at(0).dims();
80+
auto dim_0 = ref_inputs[0]->dims();
8081
for (int i = 0; i < axis; ++i) {
8182
input_rows *= dim_0[i];
8283
}
84+
8385
int input_cols = 0;
8486

8587
std::vector<int64_t> output_cols(outputs->size());
86-
for (int i = 0; i < num; ++i) {
87-
int t_cols = outputs->at(i).numel() / input_rows;
88+
for (size_t i = 0; i < num; ++i) {
89+
int t_cols = ref_inputs[i]->numel() / input_rows;
8890
input_cols += t_cols;
8991
output_cols[i] = t_cols;
9092
}
9193
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
9294

9395
// computation
94-
for (int k = 0; k < input_rows; ++k) {
96+
for (size_t k = 0; k < input_rows; ++k) {
9597
const T* src_ptr = input.data<T>() + k * input_cols;
9698
int col_idx = 0;
9799
for (int j = 0; j < num; ++j) {
98100
int col_len = output_cols[j];
99-
T* dst_ptr = outputs->at(j).data<T>() + k * col_len;
100-
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
101-
sizeof(T) * col_len);
101+
auto* out_tensor = outputs->at(j);
102+
if (out_tensor != nullptr) {
103+
T* dst_ptr = out_tensor->data<T>() + k * col_len;
104+
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
105+
sizeof(T) * col_len);
106+
}
102107
col_idx += col_len;
103108
}
104109
}

paddle/fluid/operators/math/concat.cu

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
102102
int local_col = tid_x - curr_offset;
103103
int segment_width = curr_col_offset - curr_offset;
104104
T* output_ptr = outputs_data[curr_segment];
105-
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
106-
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
107-
output_ptr[tid_y * segment_width + local_col] =
108-
input_data[tid_y * in_col + tid_x];
105+
if (output_ptr != nullptr) {
106+
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
107+
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
108+
output_ptr[tid_y * segment_width + local_col] =
109+
input_data[tid_y * in_col + tid_x];
110+
}
109111
}
110112
}
111113

@@ -118,10 +120,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
118120
int split = tid_x / fixed_out_col;
119121
int in_offset = tid_x - split * fixed_out_col;
120122
T* output_ptr = outputs_data[split];
121-
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
122-
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
123-
output_ptr[tid_y * fixed_out_col + in_offset] =
124-
input_data[tid_y * in_col + tid_x];
123+
if (output_ptr != nullptr) {
124+
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
125+
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
126+
output_ptr[tid_y * fixed_out_col + in_offset] =
127+
input_data[tid_y * in_col + tid_x];
128+
}
125129
}
126130
}
127131

@@ -203,17 +207,18 @@ template <typename T>
203207
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
204208
public:
205209
void operator()(const platform::CUDADeviceContext& context,
206-
const framework::Tensor& input, const int axis,
207-
std::vector<framework::Tensor>* outputs) {
210+
const framework::Tensor& input,
211+
const std::vector<const framework::Tensor*>& ref_inputs,
212+
const int axis, std::vector<framework::Tensor*>* outputs) {
208213
// TODO(zcd): Add input data validity checking
209214
int o_num = outputs->size();
210215
int out_row = 1;
211-
auto dim_0 = outputs->at(0).dims();
216+
auto dim_0 = ref_inputs[0]->dims();
212217
for (int i = 0; i < axis; ++i) {
213218
out_row *= dim_0[i];
214219
}
215220

216-
int out_col = outputs->at(0).numel() / out_row;
221+
int out0_col = ref_inputs[0]->numel() / out_row;
217222
int in_col = 0, in_row = out_row;
218223
bool sameShape = true;
219224

@@ -223,13 +228,17 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
223228

224229
outputs_cols[0] = 0;
225230
for (int i = 0; i < o_num; ++i) {
226-
int t_col = outputs->at(i).numel() / out_row;
231+
int t_col = outputs->at(i)->numel() / out_row;
227232
if (sameShape) {
228-
if (t_col != out_col) sameShape = false;
233+
if (t_col != out0_col) sameShape = false;
229234
}
230235
in_col += t_col;
231236
outputs_cols[i + 1] = in_col;
232-
outputs_ptr[i] = outputs->at(i).data<T>();
237+
if (outputs->at(i) != nullptr) {
238+
outputs_ptr[i] = outputs->at(i)->data<T>();
239+
} else {
240+
outputs_ptr[i] = nullptr;
241+
}
233242
}
234243

235244
T** dev_out_gpu_data =
@@ -255,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
255264

256265
if (sameShape) {
257266
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
258-
input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
267+
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
259268
} else {
260269
const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
261270
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(

paddle/fluid/operators/math/concat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ template <typename DeviceContext, typename T>
5757
class ConcatGradFunctor {
5858
public:
5959
void operator()(const DeviceContext& context, const framework::Tensor& input,
60-
const int axis, std::vector<framework::Tensor>* outputs);
60+
const std::vector<const framework::Tensor*>& ref_inputs,
61+
const int axis, std::vector<framework::Tensor*>* outputs);
6162
};
6263

6364
} // namespace math

0 commit comments

Comments
 (0)