
Commit 57af117

[PHI] Refine softmax kernel support for big tensor (#73460)
1 parent 11afda0

2 files changed: 23 additions & 18 deletions

paddle/phi/kernels/funcs/axis_utils.h

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ static inline T SizeToAxis(const int axis, DDim dims) {
 }
 
 template <typename T = int>
-static inline int SizeFromAxis(const int axis, DDim dims) {
+static inline T SizeFromAxis(const int axis, DDim dims) {
   T size = 1;
   for (int i = axis; i < dims.size(); i++) {
     size *= dims[i];
@@ -45,7 +45,7 @@ static inline int SizeFromAxis(const int axis, DDim dims) {
 }
 
 template <typename T = int>
-static inline int SizeOutAxis(const int axis, DDim dims) {
+static inline T SizeOutAxis(const int axis, DDim dims) {
   T size = 1;
   for (int i = axis + 1; i < dims.size(); i++) {
     size *= dims[i];
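
This two-character fix matters for big tensors: the helpers accumulate the product in T (callers pass T = int64_t for large shapes), but the old int return type silently truncated the 64-bit product on the way out. Below is a minimal standalone sketch of the failure mode; the reimplemented helper over std::vector<int64_t> is an illustrative stand-in for PHI's DDim, not the PHI API:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for the PHI helper: product of dims[axis..end).
// Returning T instead of int is the fix this commit makes; with the old
// int return type, a 64-bit product would be truncated at the return.
template <typename T = int>
static T SizeFromAxis(const int axis, const std::vector<int64_t>& dims) {
  T size = 1;
  for (size_t i = static_cast<size_t>(axis); i < dims.size(); i++) {
    size *= static_cast<T>(dims[i]);
  }
  return size;
}

int main() {
  // 4 billion elements from axis 0, well above INT_MAX (~2.1 billion).
  std::vector<int64_t> dims = {2, 2000000000};
  std::cout << SizeFromAxis<int64_t>(0, dims) << "\n";  // prints 4000000000
  return 0;
}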

paddle/phi/kernels/gpudnn/softmax_gpudnn.h

Lines changed: 21 additions & 16 deletions
@@ -82,9 +82,9 @@ class VecT2<phi::dtype::bfloat16> {
   using Type = int;
 };
 
-static inline int Log2Ceil(int value) {
+static inline int Log2Ceil(int64_t value) {
   int log2_value = 0;
-  while ((1 << log2_value) < value) ++log2_value;
+  while ((int64_t(1) << log2_value) < value) ++log2_value;
   return log2_value;
 }
 
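The Log2Ceil fix addresses a 32-bit shift: with int value, once a dimension exceeded 2^31 - 1, the expression 1 << log2_value overflowed int before the loop could terminate. A quick usage sketch of the function as it now stands in the diff:

#include <cstdint>
#include <iostream>

// Copied from the diff above: smallest n with 2^n >= value. The 64-bit
// literal keeps the shift well-defined for values beyond 2^31.
static inline int Log2Ceil(int64_t value) {
  int log2_value = 0;
  while ((int64_t(1) << log2_value) < value) ++log2_value;
  return log2_value;
}

int main() {
  std::cout << Log2Ceil(1024) << "\n";              // 10 (exact power of two)
  std::cout << Log2Ceil(1025) << "\n";              // 11 (rounds up)
  std::cout << Log2Ceil(int64_t(1) << 33) << "\n";  // 33, safe past INT_MAX
  return 0;
}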
@@ -836,37 +836,42 @@ void SwitchWarpSoftmaxBackward(const IndexType blocks,
  * Better performance when axis != -1
  */
 
-static void GetGridDim(
-    int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) {
+static void GetGridDim(int64_t high_dim,
+                       int64_t low_dim,
+                       const dim3& block,
+                       dim3* grid) {
   int device_id = phi::backends::gpu::GetCurrentDeviceId();
   int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id);
   int max_threads_per_mp =
       phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
   int max_threads = max_threads_per_mp * max_mp;
   int num_threads = block.x * block.y;
-  int max_num_blocks = max_threads / num_threads;
+  int64_t max_num_blocks = max_threads / num_threads;
 
-  int grid_x = (low_dim + block.x - 1) / block.x;
+  int64_t grid_x = (low_dim + block.x - 1) / block.x;
   grid_x = std::min(grid_x, max_num_blocks);
-  int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
+  int64_t grid_y = (max_num_blocks + grid_x - 1) / grid_x;
   grid_y = std::min(grid_y, high_dim);
   grid->x = grid_x;
   grid->y = grid_y;
 }
 
-static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
+static void GetBlockDim(int64_t mid_dim, int64_t low_dim, dim3* block) {
   constexpr int max_num_threads = 1024;
-  int block_x = 1 << Log2Ceil(low_dim);
-  int block_y = 1 << Log2Ceil(mid_dim);
-  block->x = std::min(block_x, 32);
-  block->y = std::min(block_y, static_cast<int>(max_num_threads / block->x));
-  block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
+  int64_t block_x = int64_t(1) << Log2Ceil(low_dim);
+  int64_t block_y = int64_t(1) << Log2Ceil(mid_dim);
+  block->x = std::min<int64_t>(block_x, 32);
+  block->y = std::min<int64_t>(block_y, max_num_threads / block->x);
+  block->x = std::min<int64_t>(block_x, max_num_threads / block->y);
 }
 
-static void GetLaunchConfig(
-    int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) {
+static void GetLaunchConfig(int64_t high_dim,
+                            int64_t mid_dim,
+                            int64_t low_dim,
+                            dim3* grid,
+                            dim3* block) {
   GetBlockDim(mid_dim, low_dim, block);
-  GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
+  GetGridDim(high_dim, low_dim, *block, grid);
 }
 
 template <typename T,
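
To see how the reworked helpers pick a launch configuration for an oversized inner dimension, here is a host-only sketch mirroring GetBlockDim and GetGridDim, with a plain struct in place of CUDA's dim3 and a hard-coded occupancy limit in place of the phi::backends::gpu device queries (both stand-ins are assumptions for illustration). Note that this commit also drops the unused mid_dim parameter from GetGridDim:

#include <algorithm>
#include <cstdint>
#include <iostream>

struct Dim3 { int64_t x = 1, y = 1, z = 1; };  // stand-in for CUDA's dim3

static inline int Log2Ceil(int64_t value) {
  int log2_value = 0;
  while ((int64_t(1) << log2_value) < value) ++log2_value;
  return log2_value;
}

// Mirrors the post-change GetBlockDim: block.x covers low_dim (capped at a
// warp of 32), block.y covers mid_dim, product capped at 1024 threads.
static void GetBlockDim(int64_t mid_dim, int64_t low_dim, Dim3* block) {
  constexpr int max_num_threads = 1024;
  int64_t block_x = int64_t(1) << Log2Ceil(low_dim);
  int64_t block_y = int64_t(1) << Log2Ceil(mid_dim);
  block->x = std::min<int64_t>(block_x, 32);
  block->y = std::min<int64_t>(block_y, max_num_threads / block->x);
  block->x = std::min<int64_t>(block_x, max_num_threads / block->y);
}

// Mirrors the post-change GetGridDim, with an assumed device occupancy
// instead of the PHI device queries used in the real code.
static void GetGridDim(int64_t high_dim, int64_t low_dim,
                       const Dim3& block, Dim3* grid) {
  const int max_threads = 132 * 2048;  // assumption: 132 SMs * 2048 threads
  int64_t max_num_blocks = max_threads / (block.x * block.y);
  int64_t grid_x = (low_dim + block.x - 1) / block.x;
  grid_x = std::min(grid_x, max_num_blocks);
  int64_t grid_y = (max_num_blocks + grid_x - 1) / grid_x;
  grid_y = std::min(grid_y, high_dim);
  grid->x = grid_x;
  grid->y = grid_y;
}

int main() {
  // low_dim alone exceeds INT_MAX, exactly the case the int64_t
  // parameters now handle without overflow.
  Dim3 block, grid;
  GetBlockDim(/*mid_dim=*/128, /*low_dim=*/int64_t(3) << 30, &block);
  GetGridDim(/*high_dim=*/4, /*low_dim=*/int64_t(3) << 30, block, &grid);
  std::cout << "block=(" << block.x << "," << block.y << ") "
            << "grid=(" << grid.x << "," << grid.y << ")\n";
  return 0;
}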
