Commit 80cca70
remove warpSize usage on host side
1 parent: 8ffba3f
7 files changed, 40 insertions(+), 15 deletions(-)

aten/src/ATen/native/cuda/SoftMax.cu (10 additions, 9 deletions)

@@ -183,15 +183,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
   uint64_t block_size = 1;
   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));

-  // We need a block size that is a multiple of C10_WARP_SIZE in order
+  // We need a block size that is a multiple of at::cuda::warp_size() in order
   // to perform block size reductions using warp shuffle instructions.
-  // Since max_threads is also a multiple of C10_WARPS_SIZE we do not
+  // Since max_threads is also a multiple of at::cuda::warp_size() we do not
   // risk creating a block size larger than the limit.

-  if (max_block_size % C10_WARP_SIZE == 0) {
+  int warp_size = at::cuda::warp_size();
+  if (max_block_size % warp_size == 0) {
     block_size = max_block_size;
   } else {
-    block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
+    block_size = (max_block_size / warp_size + 1) * warp_size;
   }

   return dim3(block_size);

@@ -1107,7 +1108,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
   if constexpr (use_fast_softmax) {
     dim3 block(512);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     if (dim_size % ILP == 0) {
       cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, EpilogueWithMul>
         <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);

@@ -1117,7 +1118,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
     }
   } else {
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
         smem_reduction_sz) / sizeof(scalar_t);

@@ -1198,7 +1199,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
   if constexpr (use_fast_softmax) {
     dim3 block(512);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     if (dim_size % ILP == 0) {
       cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, EpilogueWithMul>
         <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);

@@ -1208,7 +1209,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
     }
   } else {
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
         smem_reduction_sz) / sizeof(scalar_t);

@@ -1274,7 +1275,7 @@ void dispatch_host_softmax_backward(int64_t dim_size, dim3 grid, Tensor &grad, T
   constexpr int ILP = sizeof(float4) / sizeof(output_t);
   dim3 block = SoftMax_getBlockSize(ILP, dim_size);

-  size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+  size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
   auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
       smem_reduction_sz) / sizeof(output_t);
   bool can_use_smem = static_cast<size_t>(dim_size) < max_elements_per_smem;

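The pattern repeated in these hunks is the same: the block size is rounded up to a whole number of warps so the warp-shuffle reduction never sees a partial warp, and one accumulator slot per warp is reserved in shared memory. Below is a minimal host-side sketch of that sizing logic for a .cu translation unit; the helper names are invented for illustration, and only at::cuda::warp_size() (from ATen/cuda/CUDAContext.h) is the real ATen call the hunks rely on.

    #include <ATen/cuda/CUDAContext.h>  // at::cuda::warp_size()
    #include <algorithm>
    #include <cstdint>

    // Round the block size up to a multiple of the warp width of the current
    // device, queried at runtime on the host. If max_threads is itself a warp
    // multiple, the rounded value cannot exceed it.
    inline dim3 round_block_to_warp(uint64_t dim_size, uint64_t max_threads) {
      const uint64_t warp = static_cast<uint64_t>(at::cuda::warp_size());
      uint64_t block = std::min(dim_size, max_threads);
      if (block % warp != 0) {
        block = (block / warp + 1) * warp;
      }
      return dim3(static_cast<unsigned int>(block));
    }

    // One partial result per warp is staged in shared memory before the final
    // combine, hence block.x / warp_size slots of the accumulator type.
    inline size_t reduction_smem_bytes(dim3 block, size_t accscalar_bytes) {
      return block.x / at::cuda::warp_size() * accscalar_bytes;
    }
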
aten/src/ATen/native/cuda/TensorTopK.cu (4 additions, 0 deletions)

@@ -439,8 +439,12 @@ __global__ void computeBlockwiseWithinKCounts(
     warp_counts[warp] = count;
   }
   __syncthreads();
+#ifdef USE_ROCM
+  CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
+#else
   static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                 "Assuming only 1 warp is needed for final reduction");
+#endif
   if (warp != 0) {
     return;
   }

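The move from static_assert to CUDA_KERNEL_ASSERT follows from the Macros.h change further down: the HIP host pass also instantiates kernel bodies, and there C10_WARP_SIZE is deliberately 1, so the compile-time check would fire spuriously; the runtime assert is only evaluated on the device, where C10_WARP_SIZE is the real 32 or 64. A standalone sketch of the pattern, with a placeholder RADIX_DIGITS value and an invented kernel name:

    #include <c10/macros/Macros.h>  // C10_WARP_SIZE, CUDA_KERNEL_ASSERT

    constexpr int RADIX_DIGITS = 256;  // placeholder for illustration only

    __global__ void check_final_reduction_fits() {
    #ifdef USE_ROCM
      // Evaluated only when the kernel actually runs, where C10_WARP_SIZE is
      // the compile-time 32 or 64 of the target GPU.
      CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
    #else
      // On CUDA, C10_WARP_SIZE is a plain compile-time 32, so the check is free.
      static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                    "Assuming only 1 warp is needed for final reduction");
    #endif
    }
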
aten/src/ATen/native/cuda/block_reduce.cuh (6 additions, 0 deletions)

@@ -12,7 +12,13 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 // of which reduces C10_WARP_SIZE elements. So, at most
 // C10_WARP_SIZE**2 elements can be reduced at a time.
 // NOTE: This is >= the max block size on current hardware anyway (1024).
+// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
+// and kCUDABlockReduceMaxThreads is a host-side variable.
+#ifdef USE_ROCM
+static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size();
+#else
 constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+#endif

 // Sums `val` across all threads in a warp.
 //

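The comment in this hunk describes a two-level reduction: each warp reduces its own lanes with shuffles, warp leaders stage one partial per warp in shared memory, and a single warp then combines those partials, which is why a block can cover at most warp_size squared threads. Below is a CUDA-flavored sketch of that scheme, not the file's actual implementation; it assumes blockDim.x is a multiple of warpSize and a full 32-lane shuffle mask.

    // Two-level block sum: intra-warp shuffles, then one warp reduces the
    // per-warp partials staged in shared memory (blockDim.x / warpSize slots).
    template <typename T>
    __device__ T block_reduce_sum_sketch(T val, T* shared) {
      const int lane = threadIdx.x % warpSize;
      const int wid = threadIdx.x / warpSize;

      // Level 1: every warp reduces its own lanes.
      for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
      }
      if (lane == 0) {
        shared[wid] = val;  // one partial per warp
      }
      __syncthreads();

      // Level 2: the first warp reduces the partials (at most warpSize of them).
      const int num_warps = blockDim.x / warpSize;
      val = (threadIdx.x < num_warps) ? shared[lane] : T(0);
      if (wid == 0) {
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
          val += __shfl_down_sync(0xffffffff, val, offset);
        }
      }
      return val;  // thread 0 holds the block-wide sum
    }
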
c10/macros/Macros.h (15 additions, 1 deletion)

@@ -312,7 +312,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #endif

 #if defined(USE_ROCM)
-#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
+// C10_WARP_SIZE is only allowed for device code.
+// Host code _must_ use at::cuda::warp_size()
+// HIP header used to define warpSize as a constexpr that was either 32 or 64
+// depending on the target device, and then always set it to 64 for host code.
+// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
+// set it to something unreasonable to trigger obvious host code errors.
+#if defined(__HIP_DEVICE_COMPILE__)
+#if defined(__GFX9__)
+static constexpr int C10_WARP_SIZE = 64;
+#else // __GFX9__
+static constexpr int C10_WARP_SIZE = 32;
+#endif // __GFX9__
+#else
+static constexpr int C10_WARP_SIZE = 1;
+#endif // __HIP_DEVICE_COMPILE__
 #else
 #define C10_WARP_SIZE 32
 #endif

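The contract after this change: device code may keep using C10_WARP_SIZE, which is now a per-target constexpr (64 on gfx9, 32 otherwise), while host code must query at::cuda::warp_size(); the host-pass value of 1 is a deliberate tripwire, not a usable width. A small sketch of the intended split; the kernel and launcher names are made up for illustration, and the ATen calls are the ones used elsewhere in this commit.

    #include <ATen/cuda/CUDAContext.h>  // at::cuda::warp_size(), getCurrentCUDAStream()
    #include <c10/macros/Macros.h>      // C10_WARP_SIZE

    // Device code: C10_WARP_SIZE is a compile-time constant for the GPU being
    // compiled for, so it can size shared memory and unroll loops.
    __global__ void one_warp_kernel(int* out) {
      __shared__ int staging[C10_WARP_SIZE];
      staging[threadIdx.x] = static_cast<int>(threadIdx.x);
      __syncthreads();
      if (threadIdx.x == 0) {
        out[0] = staging[C10_WARP_SIZE - 1];
      }
    }

    // Host code: the warp width of the current device is only known at
    // runtime, so it is queried through the device properties.
    void launch_one_warp(int* out) {
      const int warp = at::cuda::warp_size();
      one_warp_kernel<<<1, warp, 0, at::cuda::getCurrentCUDAStream()>>>(out);
    }
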
third_party/composable_kernel (submodule updated, 683 files)

torch/csrc/distributed/c10d/CUDASymmetricMemory.cu (3 additions, 3 deletions)

@@ -471,7 +471,7 @@ static __global__ void barrier_kernel(
 void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       channel,
       rank_,

@@ -509,7 +509,7 @@ void CUDASymmetricMemory::put_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      dst_rank,
      channel,

@@ -553,7 +553,7 @@ void CUDASymmetricMemory::wait_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      src_rank,
      channel,

torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu (1 addition, 1 deletion)

@@ -114,7 +114,7 @@ void init_elementwise_launch_config(
     num_blocks = 1;
     num_threads = at::round_up(
         at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(C10_WARP_SIZE));
+        static_cast<size_t>(at::cuda::warp_size()));
   } else {
     num_blocks = std::min(
         at::ceil_div(numel_per_split, max_num_threads * numel_per_thread),
