diff --git a/openmp/runtime/src/include/ompx.h.var b/openmp/runtime/src/include/ompx.h.var
index 623f0b9c315bd..6884745f4240c 100644
--- a/openmp/runtime/src/include/ompx.h.var
+++ b/openmp/runtime/src/include/ompx.h.var
@@ -9,13 +9,21 @@
 #ifndef __OMPX_H
 #define __OMPX_H
 
-#ifdef __AMDGCN_WAVEFRONT_SIZE
-#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
-#else
-#define __WARP_SIZE 32
+#if (defined(__NVPTX__) || defined(__AMDGPU__))
+#include <gpuintrin.h>
+#define __OMPX_TARGET_IS_GPU
 #endif
 
 typedef unsigned long uint64_t;
+typedef unsigned int uint32_t;
+
+static inline uint32_t __warpSize(void) {
+#ifdef __OMPX_TARGET_IS_GPU
+  return __gpu_num_lanes();
+#else
+  __builtin_trap();
+#endif
+}
 
 #ifdef __cplusplus
 extern "C" {
@@ -212,7 +220,7 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) {
 ///{
 #define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
   static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta,   \
-                                    int width = __WARP_SIZE) {                 \
+                                    int width = __warpSize()) {                \
     return ompx_shfl_down_sync_##TY(mask, var, delta, width);                  \
   }
 