Skip to content

Commit d83ecb5

Browse files
committed
Added subgroup support.
1 parent 0df27ff commit d83ecb5

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

src/kernels/level1/xdot.opencl

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
// literal). Comment-out this line for syntax-highlighting when developing.
1515
R"(
1616

17-
#ifdef cl_khr_work_group_uniform_arithmetic
17+
#if defined(cl_khr_work_group_uniform_arithmetic)
1818
#pragma OPENCL EXTENSION cl_khr_work_group_uniform_arithmetic : enable
19+
#elif defined(cl_khr_subgroups)
20+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
1921
#endif
2022

2123
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
@@ -66,12 +68,22 @@ void Xdot(const int n,
6668
output[wgid] = result;
6769
}
6870
#else
69-
for (int s=WGS1/2; s>0; s=s>>1) {
70-
if (lid < s) {
71-
Add(lm[lid], lm[lid], lm[lid + s]);
72-
}
71+
#if defined(cl_khr_subgroups) || defined(__opencl_c_subgroups)
72+
lm[get_sub_group_local_id()] = sub_group_reduce_add(lm[lid]);
7373
barrier(CLK_LOCAL_MEM_FENCE);
74-
}
74+
for (int s = get_num_sub_groups() >> 1; s > 0; s >>= 1) {
75+
if (lid < s) {
76+
Add(lm[lid], lm[lid], lm[lid + s]);
77+
}
78+
}
79+
#else
80+
for (int s=WGS1/2; s>0; s=s>>1) {
81+
if (lid < s) {
82+
Add(lm[lid], lm[lid], lm[lid + s]);
83+
}
84+
barrier(CLK_LOCAL_MEM_FENCE);
85+
}
86+
#endif
7587

7688
if (lid == 0) {
7789
output[wgid] = lm[0];
@@ -97,23 +109,33 @@ void XdotEpilogue(const __global real* restrict input,
97109
Add(lm[lid], input[lid], input[lid + WGS2]);
98110
barrier(CLK_LOCAL_MEM_FENCE);
99111

100-
// Performs reduction in local memory and stores the per work group result
112+
// Performs reduction in local memory and stores final result
101113
#if defined(cl_khr_work_group_uniform_arithmetic) || defined(__opencl_c_work_group_collective_functions)
102114
real result = work_group_reduce_add(lm[lid])
103115

104116
if (lid == 0) {
105117
dot[dot_offset] = result;
106118
}
107119
#else
108-
for (int s=WGS1/2; s>0; s=s>>1) {
109-
if (lid < s) {
110-
Add(lm[lid], lm[lid], lm[lid + s]);
111-
}
120+
#if defined(cl_khr_subgroups) || defined(__opencl_c_subgroups)
121+
lm[get_sub_group_local_id()] = sub_group_reduce_add(lm[lid]);
112122
barrier(CLK_LOCAL_MEM_FENCE);
113-
}
123+
for (int s = get_num_sub_groups() >> 1; s > 0; s >>= 1) {
124+
if (lid < s) {
125+
Add(lm[lid], lm[lid], lm[lid + s]);
126+
}
127+
}
128+
#else
129+
for (int s=WGS1/2; s>0; s=s>>1) {
130+
if (lid < s) {
131+
Add(lm[lid], lm[lid], lm[lid + s]);
132+
}
133+
barrier(CLK_LOCAL_MEM_FENCE);
134+
}
135+
#endif
114136

115137
if (lid == 0) {
116-
dot[dot_offset] = lm[0];
138+
dot[dot_offset] = lm[lid];
117139
}
118140
#endif
119141
}

0 commit comments

Comments
 (0)