|
14 | 14 | // literal). Comment-out this line for syntax-highlighting when developing. |
15 | 15 | R"( |
16 | 16 |
|
| 17 | +#if defined(cl_khr_work_group_uniform_arithmetic) |
| 18 | +#pragma OPENCL EXTENSION cl_khr_work_group_uniform_arithmetic : enable |
| 19 | +#elif defined(cl_khr_subgroups) |
| 20 | +#pragma OPENCL EXTENSION cl_khr_subgroups : enable |
| 21 | +#endif |
| 22 | + |
17 | 23 | // Parameters set by the tuner or by the database. Here they are given a basic default value in case |
18 | 24 | // this kernel file is used outside of the CLBlast library. |
19 | 25 | #ifndef WGS1 |
@@ -61,12 +67,22 @@ void Xnrm2(const int n, |
61 | 67 | output[wgid] = result; |
62 | 68 | } |
63 | 69 | #else |
64 | | - for (int s=WGS1/2; s>0; s=s>>1) { |
65 | | - if (lid < s) { |
66 | | - Add(lm[lid], lm[lid], lm[lid + s]); |
67 | | - } |
| 70 | + #if defined(cl_khr_subgroups) || defined(__opencl_c_subgroups) |
| 71 | + lm[get_sub_group_local_id()] = sub_group_reduce_add(lm[lid]); |
68 | 72 | barrier(CLK_LOCAL_MEM_FENCE); |
69 | | - } |
| 73 | + for (int s = get_num_sub_groups() >> 1; s > 0; s >>= 1) { |
| 74 | + if (lid < s) { |
| 75 | + Add(lm[lid], lm[lid], lm[lid + s]); |
| 76 | + } |
| 77 | + } |
| 78 | + #else |
| 79 | + for (int s=WGS1/2; s>0; s=s>>1) { |
| 80 | + if (lid < s) { |
| 81 | + Add(lm[lid], lm[lid], lm[lid + s]); |
| 82 | + } |
| 83 | + barrier(CLK_LOCAL_MEM_FENCE); |
| 84 | + } |
| 85 | + #endif |
70 | 86 |
|
71 | 87 | if (lid == 0) { |
72 | 88 | output[wgid] = lm[0]; |
@@ -104,12 +120,22 @@ void Xnrm2Epilogue(const __global real* restrict input, |
104 | 120 | #endif |
105 | 121 | } |
106 | 122 | #else |
107 | | - for (int s=WGS1/2; s>0; s=s>>1) { |
108 | | - if (lid < s) { |
109 | | - Add(lm[lid], lm[lid], lm[lid + s]); |
110 | | - } |
| 123 | + #if defined(cl_khr_subgroups) || defined(__opencl_c_subgroups) |
| 124 | + lm[get_sub_group_local_id()] = sub_group_reduce_add(lm[lid]); |
111 | 125 | barrier(CLK_LOCAL_MEM_FENCE); |
112 | | - } |
| 126 | + for (int s = get_num_sub_groups() >> 1; s > 0; s >>= 1) { |
| 127 | + if (lid < s) { |
| 128 | + Add(lm[lid], lm[lid], lm[lid + s]); |
| 129 | + } |
| 130 | + } |
| 131 | + #else |
| 132 | + for (int s=WGS1/2; s>0; s=s>>1) { |
| 133 | + if (lid < s) { |
| 134 | + Add(lm[lid], lm[lid], lm[lid + s]); |
| 135 | + } |
| 136 | + barrier(CLK_LOCAL_MEM_FENCE); |
| 137 | + } |
| 138 | + #endif |
113 | 139 |
|
114 | 140 | if (lid == 0) { |
115 | 141 | #if PRECISION == 3232 || PRECISION == 6464 |
|
0 commit comments