@@ -183,15 +183,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
   uint64_t block_size = 1;
   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));

-  // We need a block size that is a multiple of C10_WARP_SIZE in order
+  // We need a block size that is a multiple of at::cuda::warp_size() in order
   // to perform block size reductions using warp shuffle instructions.
-  // Since max_threads is also a multiple of C10_WARPS_SIZE we do not
+  // Since max_threads is also a multiple of at::cuda::warp_size() we do not
   // risk creating a block size larger than the limit.

-  if (max_block_size % C10_WARP_SIZE == 0) {
+  int warp_size = at::cuda::warp_size();
+  if (max_block_size % warp_size == 0) {
     block_size = max_block_size;
   } else {
-    block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
+    block_size = (max_block_size / warp_size + 1) * warp_size;
   }

   return dim3(block_size);
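For intuition, the branch above simply pads max_block_size up to the next multiple of the warp width whenever it is not already aligned, so every warp taking part in the shuffle reduction is full. A minimal standalone sketch of the same arithmetic (the 32-wide warp and the sample sizes are illustrative assumptions, not values taken from this diff):

#include <cstdint>
#include <cstdio>

// Round a candidate block size up to the next multiple of warp_size,
// mirroring the branch in SoftMaxForward_getBlockSize above.
uint64_t round_up_to_warp(uint64_t max_block_size, uint64_t warp_size) {
  if (max_block_size % warp_size == 0)
    return max_block_size;
  return (max_block_size / warp_size + 1) * warp_size;
}

int main() {
  // dim_size = 100 caps the block at 100 threads; 100 is not a multiple
  // of 32, so the block is padded to 128 threads.
  std::printf("%llu\n", (unsigned long long)round_up_to_warp(100, 32));  // 128
  // Already aligned: 256 stays 256.
  std::printf("%llu\n", (unsigned long long)round_up_to_warp(256, 32));  // 256
  return 0;
}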
@@ -1107,7 +1108,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
             constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
             if constexpr (use_fast_softmax) {
               dim3 block(512);
-              size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+              size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
               if (dim_size % ILP == 0) {
                 cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, EpilogueWithMul>
                   <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
@@ -1117,7 +1118,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
               }
             } else {
               dim3 block = SoftMaxForward_getBlockSize(dim_size);
-              size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+              size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
               auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                             smem_reduction_sz) / sizeof(scalar_t);

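For context, smem_reduction_sz reserves one accumulator slot per warp in the block for the cross-warp reduction, and whatever shared memory remains is the budget for caching the softmax row itself. A rough back-of-the-envelope sketch of that split (the 512-thread block, 32-wide warp, float accumulator, and 48 KB shared-memory limit are illustrative assumptions, not values from this diff):

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed illustrative values: a 512-thread block, 32-wide warps,
  // float accumulators and a 48 KB static shared-memory limit.
  const size_t block_x = 512, warp_size = 32, smem_per_block = 48 * 1024;

  // One accumulator slot per warp, as in the smem_reduction_sz line above.
  size_t smem_reduction_sz = block_x / warp_size * sizeof(float);  // 16 * 4 = 64 bytes

  // Remaining shared memory, expressed in elements of the row's scalar type (here float).
  size_t max_elements_per_smem = (smem_per_block - smem_reduction_sz) / sizeof(float);

  std::printf("reduction bytes: %zu, cacheable elements: %zu\n",
              smem_reduction_sz, max_elements_per_smem);  // 64, 12272
  return 0;
}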
@@ -1198,7 +1199,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
             constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
             if constexpr (use_fast_softmax) {
               dim3 block(512);
-              size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+              size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
               if (dim_size % ILP == 0) {
                 cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, EpilogueWithMul>
                   <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
@@ -1208,7 +1209,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
               }
             } else {
               dim3 block = SoftMaxForward_getBlockSize(dim_size);
-              size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+              size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
               auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                             smem_reduction_sz) / sizeof(scalar_t);

@@ -1274,7 +1275,7 @@ void dispatch_host_softmax_backward(int64_t dim_size, dim3 grid, Tensor &grad, T
   constexpr int ILP = sizeof(float4) / sizeof(output_t);
   dim3 block = SoftMax_getBlockSize(ILP, dim_size);

-  size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+  size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
   auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                 smem_reduction_sz) / sizeof(output_t);
   bool can_use_smem = static_cast<size_t>(dim_size) < max_elements_per_smem;
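The common thread across these hunks is that the warp width is now read from the device at runtime instead of being fixed at compile time, which matters on ROCm builds where a wavefront can be 64 lanes wide. Outside of ATen the same property is visible through the plain CUDA runtime API; a minimal standalone query for illustration (this is not the helper the diff itself uses):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  cudaDeviceProp prop{};
  // warpSize is a per-device property: current NVIDIA GPUs report 32,
  // while many AMD GPUs report 64 through the equivalent HIP query.
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device available\n");
    return 1;
  }
  std::printf("warp size on device %d: %d\n", device, prop.warpSize);
  return 0;
}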