File tree Expand file tree Collapse file tree 3 files changed +32
-7
lines changed Expand file tree Collapse file tree 3 files changed +32
-7
lines changed Original file line number Diff line number Diff line change @@ -33,7 +33,12 @@ namespace at::native {
3333namespace {
3434
3535constexpr int kCUDANumThreads = 256 ;
36+ #ifdef USE_ROCM
37+ // C10_WARP_SIZE is not constexpr for host code.
38+ #define kWarpSize C10_WARP_SIZE
39+ #else
3640constexpr unsigned int kWarpSize = C10_WARP_SIZE;
41+ #endif
3742constexpr int vec_size = 4 ; // we could make it depend on dtype, but that would lead to different results between float and low-precision types
3843
3944// aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
Original file line number Diff line number Diff line change @@ -242,7 +242,11 @@ __global__ void coalesceValuesKernel(
242242// `if constexpr` once CUDA code is compiled as C++17; see
243243// gh-56055 for blockers.
244244template <typename Dtype>
245+ #ifdef USE_ROCM
246+ C10_LAUNCH_BOUNDS_1 (C10_WARP_SIZE_STATIC*4 )
247+ #else
245248C10_LAUNCH_BOUNDS_1 (C10_WARP_SIZE*4 )
249+ #endif
246250__global__ void coalesceValuesKernel (
247251 int64_t *segment_offsets, int64_t *value_indices,
248252 bool *values, bool *newValues,
Original file line number Diff line number Diff line change @@ -318,16 +318,32 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
318318// depending on the target device, and then always set it to 64 for host code.
319319// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_; in
320320// host code it now resolves to a runtime call to at::cuda::warp_size().
321- #if defined(__HIP_DEVICE_COMPILE__)
321+
322+ namespace at ::cuda {
323+ TORCH_CUDA_CPP_API int warp_size ();
324+ }
325+ #ifdef __HIPCC__
326+ static inline int __host__ C10_WARP_SIZE_INTERNAL () {
327+ return at::cuda::warp_size ();
328+ }
329+
330+ static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL () {
322331#if defined(__GFX9__)
323- static constexpr int C10_WARP_SIZE = 64 ;
332+ return 64 ;
324333#else // __GFX9__
325- static constexpr int C10_WARP_SIZE = 32 ;
334+ return 32 ;
326335#endif // __GFX9__
327- #else
328- static constexpr int C10_WARP_SIZE = 1 ;
329- #endif // __HIP_DEVICE_COMPILE__
330- #else
336+ }
337+ #else // __HIPCC__
338+ inline int C10_WARP_SIZE_INTERNAL () {
339+ return at::cuda::warp_size ();
340+ }
341+ #endif // __HIPCC__
342+
343+ #define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL())
344+ #define C10_WARP_SIZE_STATIC 64
345+
346+ #else // defined(USE_ROCM)
331347#define C10_WARP_SIZE 32
332348#endif
333349
You can’t perform that action at this time.
0 commit comments