@@ -82,9 +82,9 @@ class VecT2<phi::dtype::bfloat16> {
82
82
using Type = int ;
83
83
};
84
84
85
// Returns the smallest exponent e >= 0 such that 2^e >= value, i.e.
// ceil(log2(value)). Any value <= 1 (including zero and negatives) yields 0.
//
// The comparison is done in uint64_t and the exponent is capped at 63, so
// inputs above 2^62 terminate with 63 instead of shifting a signed 1 into
// (and then past) the sign bit, which is undefined behaviour and could keep
// the loop running forever.
static inline int Log2Ceil(int64_t value) {
  if (value <= 1) return 0;
  const uint64_t target = static_cast<uint64_t>(value);
  int log2_value = 0;
  while (log2_value < 63 && (uint64_t{1} << log2_value) < target) {
    ++log2_value;
  }
  return log2_value;
}
90
90
@@ -836,37 +836,42 @@ void SwitchWarpSoftmaxBackward(const IndexType blocks,
836
836
* Better performance when axis != -1
837
837
*/
838
838
839
- static void GetGridDim (
840
- int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) {
839
+ static void GetGridDim (int64_t high_dim,
840
+ int64_t low_dim,
841
+ const dim3& block,
842
+ dim3* grid) {
841
843
int device_id = phi::backends::gpu::GetCurrentDeviceId ();
842
844
int max_mp = phi::backends::gpu::GetGPUMultiProcessors (device_id);
843
845
int max_threads_per_mp =
844
846
phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor (device_id);
845
847
int max_threads = max_threads_per_mp * max_mp;
846
848
int num_threads = block.x * block.y ;
847
- int max_num_blocks = max_threads / num_threads;
849
+ int64_t max_num_blocks = max_threads / num_threads;
848
850
849
- int grid_x = (low_dim + block.x - 1 ) / block.x ;
851
+ int64_t grid_x = (low_dim + block.x - 1 ) / block.x ;
850
852
grid_x = std::min (grid_x, max_num_blocks);
851
- int grid_y = (max_num_blocks + grid_x - 1 ) / grid_x;
853
+ int64_t grid_y = (max_num_blocks + grid_x - 1 ) / grid_x;
852
854
grid_y = std::min (grid_y, high_dim);
853
855
grid->x = grid_x;
854
856
grid->y = grid_y;
855
857
}
856
858
857
- static void GetBlockDim (int mid_dim, int low_dim, dim3* block) {
859
+ static void GetBlockDim (int64_t mid_dim, int64_t low_dim, dim3* block) {
858
860
constexpr int max_num_threads = 1024 ;
859
- int block_x = 1 << Log2Ceil (low_dim);
860
- int block_y = 1 << Log2Ceil (mid_dim);
861
- block->x = std::min (block_x, 32 );
862
- block->y = std::min (block_y, static_cast < int >( max_num_threads / block->x ) );
863
- block->x = std::min (block_x, static_cast < int >( max_num_threads / block->y ) );
861
+ int64_t block_x = int64_t ( 1 ) << Log2Ceil (low_dim);
862
+ int64_t block_y = int64_t ( 1 ) << Log2Ceil (mid_dim);
863
+ block->x = std::min< int64_t > (block_x, 32 );
864
+ block->y = std::min< int64_t > (block_y, max_num_threads / block->x );
865
+ block->x = std::min< int64_t > (block_x, max_num_threads / block->y );
864
866
}
865
867
866
- static void GetLaunchConfig (
867
- int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) {
868
+ static void GetLaunchConfig (int64_t high_dim,
869
+ int64_t mid_dim,
870
+ int64_t low_dim,
871
+ dim3* grid,
872
+ dim3* block) {
868
873
GetBlockDim (mid_dim, low_dim, block);
869
- GetGridDim (high_dim, mid_dim, low_dim, *block, grid);
874
+ GetGridDim (high_dim, low_dim, *block, grid);
870
875
}
871
876
872
877
template <typename T,
0 commit comments