@@ -102,10 +102,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
     T* output_ptr = outputs_data[curr_segment];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * segment_width + local_col] =
-          input_data[tid_y * in_col + tid_x];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * segment_width + local_col] =
+            input_data[tid_y * in_col + tid_x];
+    }
   }
 }
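Note on this hunk: the new guard turns the copy into a no-op for any output slot whose pointer is null, so the caller can pass nullptr for output gradients it does not need. A minimal standalone sketch of the same guarded column-scatter (hypothetical names, not Paddle code; one output's gradient is deliberately skipped):

// sketch_split_guard.cu -- illustrative only; names are hypothetical.
// Splits a row-major [rows x (3 + 2)] matrix into two outputs by column
// ranges; output 1 is deliberately null and the guard skips it.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void SplitCols(const float* in, int rows, int in_cols,
                          const int* col_offsets, float** outs) {
  int x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; x < in_cols; x += blockDim.x * gridDim.x) {
    int seg = 0;
    while (col_offsets[seg + 1] <= x) ++seg;  // find owning output
    int local = x - col_offsets[seg];
    int width = col_offsets[seg + 1] - col_offsets[seg];
    float* out = outs[seg];
    if (out != nullptr) {  // same guard as the patch
      for (int y = blockIdx.y * blockDim.y + threadIdx.y; y < rows;
           y += blockDim.y * gridDim.y)
        out[y * width + local] = in[y * in_cols + x];
    }
  }
}

int main() {
  const int rows = 2, c0 = 3, in_cols = 5;
  float h_in[rows * in_cols] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  int h_off[3] = {0, c0, in_cols};
  float *d_in, *d_out0, **d_outs;
  int* d_off;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out0, rows * c0 * sizeof(float));
  cudaMalloc(&d_off, sizeof(h_off));
  cudaMalloc(&d_outs, 2 * sizeof(float*));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  cudaMemcpy(d_off, h_off, sizeof(h_off), cudaMemcpyHostToDevice);
  float* h_outs[2] = {d_out0, nullptr};  // second output's grad not wanted
  cudaMemcpy(d_outs, h_outs, sizeof(h_outs), cudaMemcpyHostToDevice);
  SplitCols<<<dim3(1, 1), dim3(32, 4)>>>(d_in, rows, in_cols, d_off, d_outs);
  cudaDeviceSynchronize();
  float h_out0[rows * c0];
  cudaMemcpy(h_out0, d_out0, sizeof(h_out0), cudaMemcpyDeviceToHost);
  for (float v : h_out0) printf("%g ", v);  // expected: 0 1 2 5 6 7
  printf("\n");
  return 0;
}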
@@ -118,10 +120,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
     int split = tid_x / fixed_out_col;
     int in_offset = tid_x - split * fixed_out_col;
     T* output_ptr = outputs_data[split];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * fixed_out_col + in_offset] =
-          input_data[tid_y * in_col + tid_x];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * fixed_out_col + in_offset] =
+            input_data[tid_y * in_col + tid_x];
+    }
   }
 }
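Same change in the uniform-width overload, where the owning output is found by division instead of an offset scan. A CPU reference of the index mapping, useful for checking the kernel against (hypothetical helper, not part of the patch):

// CPU reference for the uniform-width split: input column x belongs to
// output x / w at local column x % w; null entries mark skipped gradients.
#include <vector>

void SplitColsRef(const float* in, int rows, int in_cols, int w,
                  const std::vector<float*>& outs) {
  for (int x = 0; x < in_cols; ++x) {
    float* out = outs[x / w];
    if (out == nullptr) continue;  // mirrors the kernel's null guard
    for (int y = 0; y < rows; ++y)
      out[y * w + x % w] = in[y * in_cols + x];
  }
}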
@@ -203,17 +207,18 @@ template <typename T>
 class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, const int axis,
-                  std::vector<framework::Tensor>* outputs) {
+                  const framework::Tensor& input,
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
     int out_row = 1;
-    auto dim_0 = outputs->at(0).dims();
+    auto dim_0 = ref_inputs[0]->dims();
     for (int i = 0; i < axis; ++i) {
       out_row *= dim_0[i];
     }
 
-    int out_col = outputs->at(0).numel() / out_row;
+    int out0_col = ref_inputs[0]->numel() / out_row;
     int in_col = 0, in_row = out_row;
     bool sameShape = true;
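The signature change is what makes null outputs possible: shape bookkeeping now reads from `ref_inputs` (the forward-pass inputs, which always exist) rather than from `outputs->at(0)`, and `outputs` becomes a vector of pointers so individual slots can be null. A sketch of what a call site might look like (identifiers `x0`, `x1`, `dx0`, `d_out`, `dev_ctx`, and `axis` are assumed for illustration, not taken from this commit):

// Hypothetical call-site sketch: the gradient of the second concat
// input is not wanted, so its slot is left null.
std::vector<const framework::Tensor*> ref_inputs = {&x0, &x1};
std::vector<framework::Tensor*> out_grads = {&dx0, nullptr};
math::ConcatGradFunctor<platform::CUDADeviceContext, float> functor;
functor(dev_ctx, d_out, ref_inputs, axis, &out_grads);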
@@ -223,13 +228,17 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
-      int t_col = outputs->at(i).numel() / out_row;
+      int t_col = outputs->at(i)->numel() / out_row;
       if (sameShape) {
-        if (t_col != out_col) sameShape = false;
+        if (t_col != out0_col) sameShape = false;
       }
       in_col += t_col;
       outputs_cols[i + 1] = in_col;
-      outputs_ptr[i] = outputs->at(i).data<T>();
+      if (outputs->at(i) != nullptr) {
+        outputs_ptr[i] = outputs->at(i)->data<T>();
+      } else {
+        outputs_ptr[i] = nullptr;
+      }
     }
 
     T** dev_out_gpu_data =
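One thing this hunk leaves open: `t_col` is still computed through `outputs->at(i)->numel()`, yet the branch just below concedes that `outputs->at(i)` may be null, so that dereference can fault for skipped gradients. Since the new `ref_inputs` parameter carries the same shapes, a safer variant (sketch, not in this commit) would be:

// Take the column width from the forward input, which always exists,
// instead of from the possibly-null output gradient.
int t_col = ref_inputs.at(i)->numel() / out_row;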
@@ -255,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 
     if (sameShape) {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
+          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
     } else {
       const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(