Commit a749765

Author: chengduo
Refine Split op (#13967)
* speedup split_op test=develop
* speedup split_op test=develop
* rename ConcatGrad to Split
* refine concat and split test=develop
* fix compile error
Parent: 96e9b65 · commit a749765

14 files changed: +89 −75 lines

paddle/fluid/operators/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
@@ -284,10 +284,10 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
 op_library(sequence_conv_op DEPS context_project)
 op_library(sequence_pool_op DEPS sequence_pooling)
 if (NOT WIN32)
-  op_library(lstm_op DEPS sequence2batch lstm_compute)
-  op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
-  op_library(lstmp_op DEPS sequence2batch lstm_compute)
-  op_library(gru_op DEPS sequence2batch gru_compute)
+  op_library(lstm_op DEPS sequence2batch lstm_compute)
+  op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
+  op_library(lstmp_op DEPS sequence2batch lstm_compute)
+  op_library(gru_op DEPS sequence2batch gru_compute)
 endif(NOT WIN32)
 op_library(recurrent_op DEPS executor)
 op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
(whitespace-only change)

@@ -316,7 +316,7 @@ op_library(save_op DEPS lod_tensor)
 op_library(load_op DEPS lod_tensor)
 op_library(save_combine_op DEPS lod_tensor)
 op_library(load_combine_op DEPS lod_tensor)
-op_library(concat_op DEPS concat)
+op_library(concat_op DEPS concat_and_split)

 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})

@@ -348,6 +348,6 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
 cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
 cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
 if(NOT WIN32)
-  nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
+  nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
 endif()
 nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
(whitespace-only change)

paddle/fluid/operators/array_to_lod_tensor_op.cc

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <paddle/fluid/operators/math/concat.h>
+#include <paddle/fluid/operators/math/concat_and_split.h>
 #include <numeric>

 #include "paddle/fluid/framework/lod_rank_table.h"

paddle/fluid/operators/concat_op.h

Lines changed: 8 additions & 20 deletions
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/strided_memcpy.h"

 namespace paddle {

@@ -89,29 +89,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
         outputs.push_back(nullptr);
       }
     }
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();

     // Sometimes direct copies will be faster, this maybe need deeply analysis.
     if (axis == 0 && outs.size() < 10) {
-      size_t input_offset = 0;
-      const auto in_stride = framework::stride_numel(out_grad->dims());
-
-      for (size_t i = 0; i < outs.size(); ++i) {
-        auto out_stride = framework::stride_numel(ins[i]->dims());
-        auto* out = outputs[i];
-        if (out != nullptr) {
-          StridedNumelCopyWithAxis<T>(
-              ctx.device_context(), axis, out->data<T>(), out_stride,
-              out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
-        }
-        input_offset += out_stride[axis];
-      }
+      std::vector<const framework::Tensor*> ref_shape;
+      ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end());
+      StridedMemcpyWithAxis0<T>(dev_ctx, *out_grad, ref_shape, &outputs);
     } else {
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
-          concat_grad_functor;
-      concat_grad_functor(dev_ctx, *out_grad,
-                          ctx.MultiInput<framework::Tensor>("X"),
-                          static_cast<int>(axis), &outputs);
+      math::SplitFunctor<DeviceContext, T> split_functor;
+      split_functor(dev_ctx, *out_grad, ctx.MultiInput<framework::Tensor>("X"),
+                    static_cast<int>(axis), &outputs);
     }
   }
 };
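The refactor above is the heart of the PR: ConcatGradKernel now treats concat's backward pass explicitly as a split. For the fast axis == 0 path the inlined strided-copy loop is folded into the StridedMemcpyWithAxis0 helper, and the general path calls the renamed math::SplitFunctor. As a framework-free illustration of what that split computes, here is a minimal sketch in plain C++ (row-major buffers and the SplitAxis1 helper are illustrative stand-ins for framework::Tensor and the functor; only the standard library is assumed):

// split_sketch.cc - standalone illustration of the split semantics that
// SplitFunctor provides: carve a [rows x cols] row-major buffer into
// outputs of widths out_cols along the column axis.
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

template <typename T>
std::vector<std::vector<T>> SplitAxis1(const std::vector<T>& input, int rows,
                                       int cols,
                                       const std::vector<int>& out_cols) {
  std::vector<std::vector<T>> outputs;
  int col_offset = 0;
  for (int width : out_cols) {
    std::vector<T> out(static_cast<size_t>(rows) * width);
    for (int r = 0; r < rows; ++r) {
      for (int c = 0; c < width; ++c) {
        out[r * width + c] = input[r * cols + col_offset + c];
      }
    }
    col_offset += width;
    outputs.push_back(std::move(out));
  }
  assert(col_offset == cols);  // the widths must cover every column
  return outputs;
}

int main() {
  // Input = [[1,2,3,4],[5,6,7,8]], split into column widths {1, 3}.
  std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8};
  auto outs = SplitAxis1(input, /*rows=*/2, /*cols=*/4, {1, 3});
  for (const auto& o : outs) {
    for (float v : o) std::cout << v << ' ';
    std::cout << '\n';  // prints "1 5" then "2 3 4 6 7 8"
  }
  return 0;
}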

paddle/fluid/operators/detection/generate_proposal_labels_op.cc

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {

paddle/fluid/operators/lod_tensor_to_array_op.cc

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/port.h"

@@ -79,7 +79,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
 template <typename DeviceContext>
 template <typename T>
 void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() {
-  math::ConcatGradFunctor<DeviceContext, T> func;
+  math::SplitFunctor<DeviceContext, T> func;
   func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0,
        &prev_functor_->outputs_);
 }

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
@@ -1,5 +1,5 @@
 if (NOT WIN32)
-  add_subdirectory(detail)
+  add_subdirectory(detail)
 endif(NOT WIN32)

 function(math_library TARGET)
(whitespace-only change)

@@ -35,16 +35,16 @@ function(math_library TARGET)
 endfunction()

 # please add new math_library in alphabetical order
-math_library(concat)
+math_library(concat_and_split)
 math_library(context_project DEPS im2col math_function)
 math_library(cross_entropy)
 math_library(cos_sim_functor)
 math_library(depthwise_conv)
 math_library(im2col)

 if (NOT WIN32) # windows do not support avx functions yet.
-  math_library(gru_compute DEPS activation_functions math_function)
-  math_library(lstm_compute DEPS activation_functions)
+  math_library(gru_compute DEPS activation_functions math_function)
+  math_library(lstm_compute DEPS activation_functions)
 endif (NOT WIN32)

 cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
(the gru_compute/lstm_compute pair is a whitespace-only change)

@@ -58,7 +58,7 @@ math_library(sequence_pooling DEPS math_function)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function)
 if (NOT WIN32)
-  math_library(matrix_bit_code)
+  math_library(matrix_bit_code)
 endif (NOT WIN32)
 math_library(unpooling)
 math_library(vol2col)
(whitespace-only change)

@@ -72,7 +72,7 @@ if(WITH_GPU)
   nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
   nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
 endif()
-cc_test(concat_test SRCS concat_test.cc DEPS concat)
+cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel
   SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc

paddle/fluid/operators/math/concat.cc renamed to paddle/fluid/operators/math/concat_and_split.cc

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include <vector>

 namespace paddle {

@@ -67,7 +67,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
  * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
-class ConcatGradFunctor<platform::CPUDeviceContext, T> {
+class SplitFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,

@@ -111,7 +111,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
 };
 #define DEFINE_FUNCTOR(type)                                       \
   template class ConcatFunctor<platform::CPUDeviceContext, type>;  \
-  template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
+  template class SplitFunctor<platform::CPUDeviceContext, type>;

 FOR_ALL_TYPES(DEFINE_FUNCTOR);
paddle/fluid/operators/math/concat.cu renamed to paddle/fluid/operators/math/concat_and_split.cu

Lines changed: 15 additions & 15 deletions
@@ -15,7 +15,7 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
 #include "paddle/fluid/framework/mixed_vector.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"

@@ -24,7 +24,7 @@ namespace operators {
 namespace math {

 template <typename T>
-__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
+__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;

@@ -50,7 +50,7 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
 }

 template <typename T>
-__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
+__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
                              const int out_rows, const int out_cols,
                              T* output_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;

@@ -67,9 +67,9 @@ __global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
 }

 template <typename T>
-__global__ void KernelConcatGrad(const T* input_data, const int in_row,
-                                 const int in_col, const int* out_cols,
-                                 int out_cols_size, T** outputs_data) {
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int* out_cols,
+                            int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   int curr_segment = 0;
   int curr_offset = out_cols[0];

@@ -94,9 +94,9 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
 }

 template <typename T>
-__global__ void KernelConcatGrad(const T* input_data, const int in_row,
-                                 const int in_col, const int fixed_out_col,
-                                 T** outputs_data) {
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int fixed_out_col,
+                            T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x / fixed_out_col;

@@ -170,11 +170,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);

     if (sameShape) {
-      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
           dev_ins_data, in_col, out_row, out_col, output->data<T>());
     } else {
       const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
-      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
           dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
           out_row, out_col, output->data<T>());
     }

@@ -189,7 +189,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
  * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
-class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+class SplitFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,

@@ -248,11 +248,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);

     if (sameShape) {
-      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
     } else {
       const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
-      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
           static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }

@@ -264,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {

 #define DEFINE_FUNCTOR(type)                                        \
   template class ConcatFunctor<platform::CUDADeviceContext, type>;  \
-  template class ConcatGradFunctor<platform::CUDADeviceContext, type>
+  template class SplitFunctor<platform::CUDADeviceContext, type>

 FOR_ALL_TYPES(DEFINE_FUNCTOR);
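Besides the include, two things change in this file: the __global__ kernels are renamed so the names describe what they compute (ConcatKernel, SplitKernel) rather than which op's gradient invokes them, and the explicit-instantiation macro now stamps out SplitFunctor. The variable-width SplitKernel maps every global column index to the output segment that owns it by walking a running prefix of out_cols. A serial C++ sketch of that index arithmetic (names are illustrative; the real kernel performs the same scan inside a CUDA grid-stride loop):

// segment_lookup.cc - serial version of the column-to-segment mapping used
// by the variable-width SplitKernel: out_cols holds the column width of
// each output; each global column tid_x lands in exactly one segment.
#include <iostream>
#include <vector>

int main() {
  const std::vector<int> out_cols = {1, 3, 2};  // widths of three outputs
  const int in_col = 1 + 3 + 2;                 // total input columns
  int curr_segment = 0;  // segment currently being scanned
  int curr_offset = 0;   // global column where curr_segment starts
  for (int tid_x = 0; tid_x < in_col; ++tid_x) {
    // Advance until curr_segment contains tid_x; since tid_x only grows,
    // the scan never restarts, which keeps the per-thread cost low.
    while (curr_offset + out_cols[curr_segment] <= tid_x) {
      curr_offset += out_cols[curr_segment];
      ++curr_segment;
    }
    std::cout << "column " << tid_x << " -> output " << curr_segment
              << ", local column " << (tid_x - curr_offset) << '\n';
  }
  return 0;
}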

paddle/fluid/operators/math/concat.h renamed to paddle/fluid/operators/math/concat_and_split.h

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ class ConcatFunctor {
  * Output[1] = [[5,6]]
  */
 template <typename DeviceContext, typename T>
-class ConcatGradFunctor {
+class SplitFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
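For callers, only the class name changes; the signature visible above (device context, input tensor, ref_inputs carrying the target shapes, axis, output vector) is untouched. A minimal CPU usage sketch follows, assuming it is compiled inside this Paddle tree so the fluid headers resolve; the shapes and values are chosen to be consistent with the header's own Output[1] = [[5,6]] doc example:

// split_functor_usage.cc - hedged usage sketch of the renamed functor.
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h"

int main() {
  namespace f = paddle::framework;
  namespace p = paddle::platform;

  p::CPUPlace place;
  p::CPUDeviceContext ctx(place);

  // Input = [[1,2],[3,4],[5,6]]; split along axis 0 into 2 rows + 1 row.
  f::Tensor input;
  float* in = input.mutable_data<float>(f::make_ddim({3, 2}), place);
  for (int i = 0; i < 6; ++i) in[i] = static_cast<float>(i + 1);

  // ref_inputs only supplies the target shapes, exactly as ConcatGradKernel
  // passes its forward inputs "X" for shape reference.
  f::Tensor ref0, ref1;
  ref0.Resize(f::make_ddim({2, 2}));
  ref1.Resize(f::make_ddim({1, 2}));
  std::vector<const f::Tensor*> ref_inputs = {&ref0, &ref1};

  f::Tensor out0, out1;
  out0.mutable_data<float>(ref0.dims(), place);
  out1.mutable_data<float>(ref1.dims(), place);
  std::vector<f::Tensor*> outputs = {&out0, &out1};

  paddle::operators::math::SplitFunctor<p::CPUDeviceContext, float> split;
  split(ctx, input, ref_inputs, /*axis=*/0, &outputs);
  // out0 now holds [[1,2],[3,4]] and out1 holds [[5,6]].
  return 0;
}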

paddle/fluid/operators/math/concat_test.cc

Lines changed: 1 addition & 1 deletion
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/math/concat.h"
 #include <gtest/gtest.h>
 #include <vector>
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"

 template <typename DeviceContext, typename Place>
 void testConcat() {
