@@ -2239,7 +2239,7 @@ DEFN_ARITH_OPERATIONS(double)
22392239DEFN_ARITH_OPERATIONS (half )
22402240#endif // defined(cl_khr_fp16)
22412241
2242- #define DEFN_WORK_GROUP_REDUCE (func , type_abbr , type , op ) \
2242+ #define DEFN_WORK_GROUP_REDUCE (func , type_abbr , type , op , identity ) \
22432243type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X) \
22442244{ \
22452245 type sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, X); \
@@ -2248,19 +2248,41 @@ type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X)
22482248 uint num_sg = SPIRV_BUILTIN_NO_OP(BuiltInNumSubgroups, , )(); \
22492249 uint sg_lid = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupLocalInvocationId, , )(); \
22502250 uint sg_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupSize, , )(); \
2251+ uint sg_max_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupMaxSize, , )(); \
22512252 \
2252- if (sg_lid == sg_size - 1 ) { \
2253+ if (sg_lid == 0 ) { \
22532254 scratch[sg_id] = sg_x; \
22542255 } \
22552256 SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
22562257 \
2257- type sg_aggregate = scratch[0]; \
2258- for (int s = 1; s < num_sg; ++s) { \
2259- sg_aggregate = op(sg_aggregate, scratch[s]); \
2258+ uint values_num = num_sg; \
2259+ while(values_num > sg_max_size) { \
2260+ uint max_id = ((values_num + sg_max_size - 1) / sg_max_size) * sg_max_size; \
2261+ uint global_id = sg_id * sg_max_size + sg_lid; \
2262+ if (global_id < max_id) { \
2263+ type value = global_id < values_num ? scratch[sg_id * sg_max_size + sg_lid] : identity; \
2264+ sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value);\
2265+ if (sg_lid == 0) { \
2266+ scratch[sg_id] = sg_x; \
2267+ } \
2268+ } \
2269+ values_num = max_id / sg_max_size; \
2270+ SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
22602271 } \
22612272 \
2273+ type result; \
2274+ if (values_num > sg_size) { \
2275+ type sg_aggregate = scratch[0]; \
2276+ for (int s = 1; s < values_num; ++s) { \
2277+ sg_aggregate = op(sg_aggregate, scratch[s]); \
2278+ } \
2279+ result = sg_aggregate; \
2280+ } else { \
2281+ type value = sg_lid < values_num ? scratch[sg_lid] : identity; \
2282+ result = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value); \
2283+ } \
22622284 SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
2263- return sg_aggregate; \
2285+ return result; \
22642286}
22652287
22662288
@@ -2463,7 +2485,7 @@ DEFN_SUB_GROUP_REDUCE(func, type_abbr, type, op, identity, signed_cast)
24632485DEFN_SUB_GROUP_SCAN_INCL(func, type_abbr, type, op, identity) \
24642486DEFN_SUB_GROUP_SCAN_EXCL(func, type_abbr, type, op, identity) \
24652487 \
2466- DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op) \
2488+ DEFN_WORK_GROUP_REDUCE(func, type_abbr, type, op, identity) \
24672489DEFN_WORK_GROUP_SCAN_INCL(func, type_abbr, type, op) \
24682490DEFN_WORK_GROUP_SCAN_EXCL(func, type_abbr, type, op, identity) \
24692491 \
0 commit comments