
Commit 4edcb2c

jeffdaily authored and jithunnair-amd committed
remove warpSize usage on host side
1 parent 67e1950 commit 4edcb2c

File tree

7 files changed, +37 -12 lines changed

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 7 additions & 6 deletions
@@ -177,15 +177,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
   uint64_t block_size = 1;
   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));
 
-  // We need a block size that is a multiple of C10_WARP_SIZE in order
+  // We need a block size that is a multiple of at::cuda::warp_size() in order
   // to perform block size reductions using warp shuffle instructions.
   // Since max_threads is also a multiple of C10_WARPS_SIZE we do not
   // risk creating a block size larger than the limit.
 
-  if (max_block_size % C10_WARP_SIZE == 0) {
+  int warp_size = at::cuda::warp_size();
+  if (max_block_size % warp_size == 0) {
     block_size = max_block_size;
   } else {
-    block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
+    block_size = (max_block_size / warp_size + 1) * warp_size;
   }
 
   return dim3(block_size);
@@ -978,7 +979,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   } else {
     constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
         smem_reduction_sz) / sizeof(scalar_t);
 
@@ -1057,7 +1058,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   } else {
     constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
         smem_reduction_sz) / sizeof(scalar_t);
 
@@ -1122,7 +1123,7 @@ void dispatch_host_softmax_backward(int64_t dim_size, dim3 grid, Tensor &grad, T
   constexpr int ILP = sizeof(float4) / sizeof(output_t);
   dim3 block = SoftMax_getBlockSize(ILP, dim_size);
 
-  size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+  size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
   auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
       smem_reduction_sz) / sizeof(output_t);
   bool can_use_smem = static_cast<size_t>(dim_size) < max_elements_per_smem;
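
For context on the first hunk: the rounding can only be done correctly on the host by querying the device at run time, because on ROCm the warp size is 32 or 64 depending on the GPU. Below is a minimal host-side sketch of the same rounding, assuming at::cuda::warp_size() is available from ATen/cuda/CUDAContext.h; the helper name round_up_to_warp_multiple is hypothetical and not part of this patch.

// Sketch: query the warp size from the current device at run time and round a
// candidate block size up to the next multiple of it, so warp-shuffle
// reductions always operate on whole warps.
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>

static inline uint64_t round_up_to_warp_multiple(uint64_t max_block_size) {
  const int warp_size = at::cuda::warp_size();  // 32 on CUDA; 32 or 64 on ROCm
  if (max_block_size % warp_size == 0) {
    return max_block_size;
  }
  return (max_block_size / warp_size + 1) * warp_size;
}

The later hunks in this file reuse the same runtime value when sizing the shared-memory reduction scratch (block.x / at::cuda::warp_size() values of accscalar_t).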

aten/src/ATen/native/cuda/TensorTopK.cu

Lines changed: 4 additions & 0 deletions
@@ -439,8 +439,12 @@ __global__ void computeBlockwiseWithinKCounts(
     warp_counts[warp] = count;
   }
   __syncthreads();
+#ifdef USE_ROCM
+  CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
+#else
   static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                 "Assuming only 1 warp is needed for final reduction");
+#endif
   if (warp != 0) {
     return;
   }
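
Why this hunk is more than a plain swap: with the Macros.h change below, the host pass of the HIP compiler sees C10_WARP_SIZE == 1, so a static_assert against it inside a kernel would presumably trip during that pass even though the device-pass value is correct. The ROCm branch therefore moves the invariant to a device-side runtime check. A standalone sketch of the same pattern, assuming CUDA_KERNEL_ASSERT and C10_WARP_SIZE come from c10/macros/Macros.h as in the file above; the kernel name check_final_reduction_fits is hypothetical.

// Sketch: a compile-time invariant that depends on C10_WARP_SIZE becomes a
// device-side runtime assertion on ROCm, where only the device pass sees the
// real warp size.
#include <c10/macros/Macros.h>

template <int RADIX_DIGITS>
__global__ void check_final_reduction_fits() {
#ifdef USE_ROCM
  // Evaluated on the device, where C10_WARP_SIZE is the true 32 or 64.
  CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
#else
  static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                "Assuming only 1 warp is needed for final reduction");
#endif
}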

aten/src/ATen/native/cuda/block_reduce.cuh

Lines changed: 6 additions & 0 deletions
@@ -12,7 +12,13 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 // of which reduces C10_WARP_SIZE elements. So, at most
 // C10_WARP_SIZE**2 elements can be reduced at a time.
 // NOTE: This is >= the max block size on current hardware anyway (1024).
+// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
+// and kCUDABlockReduceMaxThreads is a host-side variable.
+#ifdef USE_ROCM
+static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size();
+#else
 constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+#endif
 
 // Sums `val` across all threads in a warp.
 //
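
Because at::cuda::warp_size() is a run-time query, kCUDABlockReduceMaxThreads stops being a constant expression on ROCm, so callers have to clamp against it at run time rather than in constexpr contexts. A hedged sketch of such a caller; the function name pick_block_threads is illustrative and not part of the patch.

// Sketch: clamp a requested block size against the block-reduce limit, which
// on ROCm is warp_size() squared computed at run time instead of a
// compile-time constant.
#include <algorithm>
#include <ATen/cuda/CUDAContext.h>

static int pick_block_threads(int requested) {
#ifdef USE_ROCM
  const int max_threads = at::cuda::warp_size() * at::cuda::warp_size();
#else
  constexpr int max_threads = 32 * 32;  // C10_WARP_SIZE * C10_WARP_SIZE on CUDA
#endif
  return std::min(requested, max_threads);
}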

c10/macros/Macros.h

Lines changed: 15 additions & 1 deletion
@@ -312,7 +312,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #endif
 
 #if defined(USE_ROCM)
-#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
+// C10_WARP_SIZE is only allowed for device code.
+// Host code _must_ use at::cuda::warp_size()
+// HIP header used to define warpSize as a constexpr that was either 32 or 64
+// depending on the target device, and then always set it to 64 for host code.
+// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
+// set it to something unreasonable to trigger obvious host code errors.
+#if defined(__HIP_DEVICE_COMPILE__)
+#if defined(__GFX9__)
+static constexpr int C10_WARP_SIZE = 64;
+#else // __GFX9__
+static constexpr int C10_WARP_SIZE = 32;
+#endif // __GFX9__
+#else
+static constexpr int C10_WARP_SIZE = 1;
+#endif // __HIP_DEVICE_COMPILE__
 #else
 #define C10_WARP_SIZE 32
 #endif
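
The net effect on ROCm is a strict split: device code keeps a genuine constexpr warp size (64 on gfx9, 32 otherwise), while host code that still reads C10_WARP_SIZE gets the deliberately unreasonable value 1 and should fail in an obvious way. A minimal sketch of the intended usage split; the kernel and launcher names are hypothetical, and at::cuda::warp_size() / at::cuda::getCurrentCUDAStream() are the host-side calls already used elsewhere in this commit.

#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>

__global__ void fill_warp_ids(int* warp_id, int n) {
  // Device pass: C10_WARP_SIZE is a real constexpr here (64 on gfx9, else 32).
  const int i = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  if (i < n) {
    warp_id[i] = static_cast<int>(threadIdx.x) / C10_WARP_SIZE;
  }
}

void launch_fill_warp_ids(int* warp_id, int n) {
  // Host pass: never read C10_WARP_SIZE here (it is 1 on ROCm by design);
  // query the device property through at::cuda::warp_size() instead.
  const int threads = 4 * at::cuda::warp_size();  // a few warps per block
  const int blocks = (n + threads - 1) / threads;
  fill_warp_ids<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(warp_id, n);
}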

third_party/composable_kernel

Submodule composable_kernel updated 683 files

torch/csrc/distributed/c10d/CUDASymmetricMemory.cu

Lines changed: 3 additions & 3 deletions
@@ -471,7 +471,7 @@ static __global__ void barrier_kernel(
 void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       channel,
       rank_,
@@ -509,7 +509,7 @@ void CUDASymmetricMemory::put_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       dst_rank,
       channel,
@@ -553,7 +553,7 @@ void CUDASymmetricMemory::wait_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       src_rank,
       channel,

torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void init_elementwise_launch_config(
     num_blocks = 1;
     num_threads = at::round_up(
         at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(C10_WARP_SIZE));
+        static_cast<size_t>(at::cuda::warp_size()));
   } else {
     num_blocks = std::min(
         at::ceil_div(numel_per_split, max_num_threads * numel_per_thread),
