
Commit 8c5e432

Author: zhangkaihuo
[cherry-pick] Optimize sparse kernel and fix some bug (#50118)
Cherry-picks several PRs that optimize sparse kernels and fix bugs: #47736 #47703 #47604 #46679 #48439 #49009 #49734
1 parent e32ff65 commit 8c5e432

File tree: 12 files changed (+1322 −197 lines)


cmake/external/cutlass.cmake

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(ExternalProject)
+
+set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)
+
+set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
+set(CUTLASS_TAG v2.9.1)
+
+include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/")
+include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/")
+include_directories(
+  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")
+
+add_definitions("-DPADDLE_WITH_CUTLASS")
+
+ExternalProject_Add(
+  extern_cutlass
+  ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
+  GIT_REPOSITORY ${CUTLASS_REPOSITORY}
+  GIT_TAG "${CUTLASS_TAG}"
+  PREFIX ${CUTLASS_PREFIX_DIR}
+  UPDATE_COMMAND ""
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND ""
+  INSTALL_COMMAND ""
+  TEST_COMMAND "")
+
+add_library(cutlass INTERFACE)
+
+add_dependencies(cutlass extern_cutlass)
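Because cutlass is added only as a header-only INTERFACE target and the global define -DPADDLE_WITH_CUTLASS is set, kernel sources can conditionally compile CUTLASS-backed paths. A minimal sketch of that guard pattern follows; the function name SparseConvWithCutlass is hypothetical and only illustrates how the macro defined by this file might be used.

// Sketch only: guard CUTLASS-dependent code on the macro defined by
// cmake/external/cutlass.cmake. The function name is hypothetical.
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/cutlass.h"  // header-only; found via the include paths above

void SparseConvWithCutlass() {
  // A CUTLASS-based implementation would go here.
}
#else
void SparseConvWithCutlass() {
  // Fallback for builds without CUTLASS (CUDA < 11.0, Windows, ARM, macOS).
}
#endif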

cmake/third_party.cmake

Lines changed: 10 additions & 0 deletions
@@ -492,4 +492,14 @@ if(WITH_CUSPARSELT)
   list(APPEND third_party_deps extern_cusparselt)
 endif()
 
+if(WITH_GPU
+   AND NOT WITH_ARM
+   AND NOT WIN32
+   AND NOT APPLE)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
+    include(external/cutlass) # download cutlass
+    list(APPEND third_party_deps extern_cutlass)
+  endif()
+endif()
+
 add_custom_target(third_party ALL DEPENDS ${third_party_deps})

paddle/phi/kernels/funcs/norm_utils.h

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,10 @@ limitations under the License. */
 
 namespace phi {
 namespace funcs {
+#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
+#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
+#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535
+
 inline void ExtractNCWHD(const phi::DDim &dims,
                          const DataLayout &data_layout,
                          int *N,

paddle/phi/kernels/funcs/sparse/utils.cu.h

Lines changed: 13 additions & 0 deletions
@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
   }
 }
 
+inline __device__ bool SetBits(const int value, int* ptr) {
+  const int index = value >> 5;
+  const int mask = 1 << (value & 31);
+  const int old = atomicOr(ptr + index, mask);
+  return (mask & old) != 0;
+}
+
+inline __device__ bool TestBits(const int value, const int* ptr) {
+  const int index = value >> 5;
+  const int mask = 1 << (value & 31);
+  return (mask & ptr[index]) != 0;
+}
+
 }  // namespace sparse
 }  // namespace funcs
 }  // namespace phi
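SetBits and TestBits treat an int array as a bitmask: SetBits atomically sets bit `value` via atomicOr and reports whether it was already set, while TestBits only reads it. A minimal sketch of how such a mask could be used to count distinct indices on the GPU follows; the kernel and its parameters are illustrative, not part of the commit, and assume indices are non-negative and smaller than the mask capacity (mask holds at least (max_value + 31) / 32 zero-initialized ints).

#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"

// Illustrative kernel (not in the commit): mark each index once and count
// how many distinct values appear in `indices`.
__global__ void CountUniqueIndices(const int* indices,
                                   const int n,
                                   int* mask,
                                   int* unique_count) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;
  for (int i = tid; i < n; i += stride) {
    // SetBits returns true when the bit was already set by an earlier call,
    // so only the first thread to see a value increments the counter.
    if (!phi::funcs::sparse::SetBits(indices[i], mask)) {
      atomicAdd(unique_count, 1);
    }
  }
}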

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 11 additions & 9 deletions
@@ -852,15 +852,17 @@ void BatchNormGradRawKernel(const Context &ctx,
       //         ctx.GetPlace()),
       //     epsilon, saved_mean_data, saved_var_data));
 #else
-    // CUDNN only support small batch size
-    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
-    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
-    const bool use_native_kernel =
-        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
-    if (use_native_kernel) {
-      if (x_dims.size() == 2) {
+    }
+    // CUDNN only support small batch size
+    bool use_native_nhwc =
+        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+            : false;
+    const bool use_native_kernel =
+        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
+    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
+      if (use_native_kernel || use_native_nhwc) {
+        if (x_dims.size() == 2 || use_native_nhwc) {
           dim3 block;
           dim3 grid;
           const int block_size = 512;

paddle/phi/kernels/gpu/batch_norm_kernel.cu

Lines changed: 74 additions & 18 deletions
@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
   }
 }
 
+template <typename T>
+static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
+                                        const double epsilon,
+                                        const int C,
+                                        BatchNormParamType<T> *inv_variance) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < C) {
+    inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
+  }
+}
+
+template <typename T, phi::DataLayout layout>
+static __global__ void BN1DForwardInference(
+    const T *x,
+    const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *inv_variance,
+    const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *bias,
+    const int C,
+    const int N,
+    const int HxW,
+    const double epsilon,
+    T *y) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int num = N * C * HxW;
+  for (int i = gid; i < num; i += stride) {
+    const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
+    BatchNormParamType<T> x_sub_mean =
+        static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
+    y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
+  }
+}
+
 template <typename T, int BlockDim, phi::DataLayout layout>
 static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
     const T *x,

@@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx,
 
   auto handle = ctx.cudnn_handle();
 
-  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-  const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
-
   // Now, depending on whether we are running test or not, we have two paths.
   // It is training mode when it's not reference AND not using pre-trained
   // model.

@@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx,
       //     epsilon));
 #else
       const bool use_native_kernel =
-          ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+          (x_dims.size() == 2 ||
+           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
       if (use_native_kernel) {
         const int block_size = 256;
         const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;

@@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx,
               epsilon,
               transformed_y.template data<T>());
         } else {
-          BNForwardInference<T, DataLayout::kNHWC>
-              <<<grid_size, block_size, 0, ctx.stream()>>>(
-                  transformed_x.template data<T>(),
-                  est_mean->template data<BatchNormParamType<T>>(),
-                  est_var->template data<BatchNormParamType<T>>(),
-                  scale.template data<BatchNormParamType<T>>(),
-                  bias.template data<BatchNormParamType<T>>(),
-                  C,
-                  N,
-                  H * W * D,
-                  epsilon,
-                  transformed_y.template data<T>());
+          if (x_dims.size() == 2) {
+            DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
+            auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
+            const int threads = 512 > C ? C : 512;
+            const int blocks = (C + 511) / 512;
+            InverseVariance<T><<<blocks, threads>>>(
+                est_var->template data<BatchNormParamType<T>>(),
+                epsilon,
+                C,
+                inv_var_ptr);
+            BN1DForwardInference<T, DataLayout::kNHWC>
+                <<<grid_size, block_size, 0, ctx.stream()>>>(
+                    transformed_x.template data<T>(),
+                    est_mean->template data<BatchNormParamType<T>>(),
+                    // est_var->template data<BatchNormParamType<T>>(),
+                    inv_var_ptr,
+                    scale.template data<BatchNormParamType<T>>(),
+                    bias.template data<BatchNormParamType<T>>(),
+                    C,
+                    N,
+                    H * W * D,
+                    epsilon,
+                    transformed_y.template data<T>());
+          } else {
+            BNForwardInference<T, DataLayout::kNHWC>
+                <<<grid_size, block_size, 0, ctx.stream()>>>(
+                    transformed_x.template data<T>(),
+                    est_mean->template data<BatchNormParamType<T>>(),
+                    est_var->template data<BatchNormParamType<T>>(),
+                    scale.template data<BatchNormParamType<T>>(),
+                    bias.template data<BatchNormParamType<T>>(),
+                    C,
+                    N,
+                    H * W * D,
+                    epsilon,
+                    transformed_y.template data<T>());
+          }
         }
       } else {
         PADDLE_ENFORCE_GPU_SUCCESS(

@@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx,
       // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
       const bool use_native_kernel =
           ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
       if (use_native_kernel) {
         dim3 block;
         dim3 grid;
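For 2-D (per-activation) inference inputs, the new path first launches InverseVariance to cache 1/sqrt(variance + epsilon) per channel and then lets BN1DForwardInference reuse that cached inverse standard deviation in its element-wise loop, avoiding a sqrt per element. A self-contained sketch of the same two-kernel pattern in plain CUDA, using float buffers instead of Paddle's tensor types, follows for illustration only.

// Sketch of the precomputed inverse-variance pattern (not Paddle's code).
__global__ void InvStdKernel(const float* var, const double eps, const int C,
                             float* inv_std) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < C) {
    inv_std[c] = 1.0f / sqrtf(var[c] + static_cast<float>(eps));
  }
}

__global__ void BNInferNC(const float* x, const float* mean,
                          const float* inv_std, const float* scale,
                          const float* bias, const int N, const int C,
                          float* y) {
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N * C; i += stride) {
    const int c = i % C;  // NC layout: channel is the fastest-varying index
    y[i] = scale[c] * (x[i] - mean[c]) * inv_std[c] + bias[c];
  }
}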
