File tree Expand file tree Collapse file tree 3 files changed +41
-2
lines changed Expand file tree Collapse file tree 3 files changed +41
-2
lines changed Original file line number Diff line number Diff line change @@ -33,7 +33,12 @@ namespace at::native {
3333namespace {
3434
3535constexpr int kCUDANumThreads = 256 ;
// kWarpSize: under USE_ROCM this is a preprocessor macro that expands to
// C10_WARP_SIZE (being a macro, it is not confined to the enclosing
// namespace); otherwise it is a constexpr unsigned int initialized from
// C10_WARP_SIZE.
36+ #ifdef USE_ROCM
37+ // C10_WARP_SIZE is not constexpr for host code.
38+ #define kWarpSize C10_WARP_SIZE
39+ #else
3640constexpr unsigned int kWarpSize = C10_WARP_SIZE;
41+ #endif
3742constexpr int vec_size = 4 ; // we could make it dependent on dtype, but that would lead to different results between float and low-p types
3843
3944// aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
Original file line number Diff line number Diff line change @@ -242,7 +242,11 @@ __global__ void coalesceValuesKernel(
242242// `if constexpr` when CUDA codes will be compiled under C++-17, see
243243// gh-56055 for blockers.
244244template <typename Dtype>
245+ #ifdef USE_ROCM
246+ C10_LAUNCH_BOUNDS_1 (C10_WARP_SIZE_STATIC*4 )
247+ #else
245248C10_LAUNCH_BOUNDS_1 (C10_WARP_SIZE*4 )
249+ #endif
246250__global__ void coalesceValuesKernel (
247251 int64_t *segment_offsets, int64_t *value_indices,
248252 bool *values, bool *newValues,
Original file line number Diff line number Diff line change @@ -312,8 +312,38 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
312312#endif
313313
314314#if defined(USE_ROCM)
315- #define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
316- #else
315+ // C10_WARP_SIZE is only allowed for device code.
316+ // Host code _must_ use at::cuda::warp_size()
317+ // HIP header used to define warpSize as a constexpr that was either 32 or 64
318+ // depending on the target device, and then always set it to 64 for host code.
319+ // Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_, so
320+ // on host it expands to a non-constexpr call to at::cuda::warp_size().
321+
322+ namespace at ::cuda {
323+ TORCH_CUDA_CPP_API int warp_size ();
324+ }
325+ #ifdef __HIPCC__
// Host-side overload: returns the value of at::cuda::warp_size(). It is not
// declared constexpr, so the result is obtained through a function call
// rather than being a compile-time constant.
326+ static inline int __host__ C10_WARP_SIZE_INTERNAL () {
327+ return at::cuda::warp_size ();
328+ }
329+
// Device-side overload (constexpr): yields 64 when __GFX9__ is defined and
// 32 otherwise.
330+ static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL () {
331+ #if defined(__GFX9__)
332+ return 64 ;
333+ #else // __GFX9__
334+ return 32 ;
335+ #endif // __GFX9__
336+ }
337+ #else // __HIPCC__
// Fallback used when __HIPCC__ is not defined: a single non-constexpr
// function that forwards to at::cuda::warp_size().
338+ inline int C10_WARP_SIZE_INTERNAL () {
339+ return at::cuda::warp_size ();
340+ }
341+ #endif // __HIPCC__
342+
343+ #define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL())
344+ #define C10_WARP_SIZE_STATIC 64
345+
346+ #else // defined(USE_ROCM)
317347#define C10_WARP_SIZE 32
318348#endif
319349
You can’t perform that action at this time.
0 commit comments