Skip to content

Commit 457338c

Browse files
committed
Added subgroup support for xnrm2.
1 parent d83ecb5 commit 457338c

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

src/kernels/level1/xnrm2.opencl

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,12 @@
1414
// literal). Comment-out this line for syntax-highlighting when developing.
1515
R"(
1616

17+
#if defined(cl_khr_work_group_uniform_arithmetic)
18+
#pragma OPENCL EXTENSION cl_khr_work_group_uniform_arithmetic : enable
19+
#elif defined(cl_khr_subgroups)
20+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
21+
#endif
22+
1723
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
1824
// this kernel file is used outside of the CLBlast library.
1925
#ifndef WGS1
@@ -61,12 +67,22 @@ void Xnrm2(const int n,
6167
output[wgid] = result;
6268
}
6369
#else
64-
for (int s=WGS1/2; s>0; s=s>>1) {
65-
if (lid < s) {
66-
Add(lm[lid], lm[lid], lm[lid + s]);
67-
}
70+
#if defined(cl_khr_subgroups) || defined(__opencl_c_subgroups)
71+
lm[get_sub_group_local_id()] = sub_group_reduce_add(lm[lid]);
6872
barrier(CLK_LOCAL_MEM_FENCE);
69-
}
73+
for (int s = get_num_sub_groups() >> 1; s > 0; s >>= 1) {
74+
if (lid < s) {
75+
Add(lm[lid], lm[lid], lm[lid + s]);
76+
}
77+
}
78+
#else
79+
for (int s=WGS1/2; s>0; s=s>>1) {
80+
if (lid < s) {
81+
Add(lm[lid], lm[lid], lm[lid + s]);
82+
}
83+
barrier(CLK_LOCAL_MEM_FENCE);
84+
}
85+
#endif
7086

7187
if (lid == 0) {
7288
output[wgid] = lm[0];
@@ -104,12 +120,22 @@ void Xnrm2Epilogue(const __global real* restrict input,
104120
#endif
105121
}
106122
#else
107-
for (int s=WGS1/2; s>0; s=s>>1) {
108-
if (lid < s) {
109-
Add(lm[lid], lm[lid], lm[lid + s]);
110-
}
123+
#if defined(cl_khr_subgroups) || defined(__opencl_c_subgroups)
124+
lm[get_sub_group_local_id()] = sub_group_reduce_add(lm[lid]);
111125
barrier(CLK_LOCAL_MEM_FENCE);
112-
}
126+
for (int s = get_num_sub_groups() >> 1; s > 0; s >>= 1) {
127+
if (lid < s) {
128+
Add(lm[lid], lm[lid], lm[lid + s]);
129+
}
130+
}
131+
#else
132+
for (int s=WGS1/2; s>0; s=s>>1) {
133+
if (lid < s) {
134+
Add(lm[lid], lm[lid], lm[lid + s]);
135+
}
136+
barrier(CLK_LOCAL_MEM_FENCE);
137+
}
138+
#endif
113139

114140
if (lid == 0) {
115141
#if PRECISION == 3232 || PRECISION == 6464

0 commit comments

Comments
 (0)