@@ -1799,28 +1799,29 @@ DEFN_ARITH_OPERATIONS(half)
17991799#endif // defined(cl_khr_fp16)
18001800
18011801#define DEFN_WORK_GROUP_REDUCE (type , op , identity , X ) \
1802- { \
1803- GET_MEMPOOL_PTR(data, type, true, 0) \
1804- uint lid = __spirv_BuiltInLocalInvocationIndex(); \
1805- uint lsize = __spirv_WorkgroupSize(); \
1806- data[lid] = X; \
1807- \
1808- uint i = lsize / 2; \
1809- while(i > 0) \
1810- { \
1802+ { \
1803+ GET_MEMPOOL_PTR(data, type, true, 0) \
1804+ uint lid = __spirv_BuiltInLocalInvocationIndex(); \
1805+ uint lsize = __spirv_WorkgroupSize(); \
1806+ data[lid] = X; \
1807+ __builtin_spirv_OpControlBarrier_i32_i32_i32(Execution, 0, AcquireRelease | WorkgroupMemory); \
1808+ uint mask = 1 << ( ((8 * sizeof(uint)) - __builtin_spirv_OpenCL_clz_i32(lsize - 1)) - 1) ; \
1809+ while( mask > 0 ) \
1810+ { \
1811+ uint c = lid ^ mask; \
1812+ type other = ( c < lsize ) ? data[ c ] : identity; \
1813+ X = op( other, X ); \
18111814 __builtin_spirv_OpControlBarrier_i32_i32_i32(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
1812- if ((lid < i) && (lid + i < lsize)) \
1813- { \
1814- X = op(X, data[lid + i]); \
1815- data[lid] = X; \
1816- } \
1817- i >>= 1; \
1818- } \
1819- lid >>= 15;\
1815+ data[lid] = X; \
1816+ __builtin_spirv_OpControlBarrier_i32_i32_i32(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
1817+ mask >>= 1; \
1818+ } \
1819+ type ret = data[0]; \
18201820 __builtin_spirv_OpControlBarrier_i32_i32_i32(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
1821- return data[0+lid]; \
1821+ return ret; \
18221822}
18231823
1824+
18241825#define DEFN_WORK_GROUP_SCAN_INCL (type , op , identity , X ) \
18251826{ \
18261827 GET_MEMPOOL_PTR(data, type, true, 0) \
@@ -1843,6 +1844,7 @@ DEFN_ARITH_OPERATIONS(half)
18431844 return X; \
18441845}
18451846
1847+
18461848#define DEFN_WORK_GROUP_SCAN_EXCL (type , op , identity , X ) \
18471849{ \
18481850 GET_MEMPOOL_PTR(data, type, true, 1) \
0 commit comments