1414// literal). Comment out this line for syntax highlighting when developing.
1515R "(
1616
17- #ifdef cl_khr_work_group_uniform_arithmetic
17+ #if defined( cl_khr_work_group_uniform_arithmetic )
1818#pragma OPENCL EXTENSION cl_khr_work_group_uniform_arithmetic : enable
19+ #elif defined(cl_khr_subgroups )
20+ #pragma OPENCL EXTENSION cl_khr_subgroups : enable
1921#endif
2022
2123// Parameters set by the tuner or by the database. Here they are given a basic default value in case
@@ -66,12 +68,22 @@ void Xdot(const int n,
6668 output [wgid ] = result ;
6769 }
6870 #else
69- for (int s = WGS1 /2 ; s > 0 ; s = s >>1 ) {
70- if (lid < s ) {
71- Add (lm [lid ], lm [lid ], lm [lid + s ]);
72- }
71+ #if defined(cl_khr_subgroups ) || defined(__opencl_c_subgroups )
72+ lm [get_sub_group_local_id ()] = sub_group_reduce_add (lm [lid ]);
7373 barrier (CLK_LOCAL_MEM_FENCE );
74- }
74+ for (int s = get_num_sub_groups () >> 1 ; s > 0 ; s >>= 1 ) {
75+ if (lid < s ) {
76+ Add (lm [lid ], lm [lid ], lm [lid + s ]);
77+ }
78+ }
79+ #else
80+ for (int s = WGS1 /2 ; s > 0 ; s = s >>1 ) {
81+ if (lid < s ) {
82+ Add (lm [lid ], lm [lid ], lm [lid + s ]);
83+ }
84+ barrier (CLK_LOCAL_MEM_FENCE );
85+ }
86+ #endif
7587
7688 if (lid == 0 ) {
7789 output [wgid ] = lm [0 ];
@@ -97,23 +109,33 @@ void XdotEpilogue(const __global real* restrict input,
97109 Add (lm [lid ], input [lid ], input [lid + WGS2 ]);
98110 barrier (CLK_LOCAL_MEM_FENCE );
99111
100- // Performs reduction in local memory and stores the per work group result
112+ // Performs reduction in local memory and stores the final result
101113 #if defined(cl_khr_work_group_uniform_arithmetic ) || defined(__opencl_c_work_group_collective_functions )
102114 real result = work_group_reduce_add (lm [lid ])
103115
104116 if (lid == 0 ) {
105117 dot [dot_offset ] = result ;
106118 }
107119 #else
108- for (int s = WGS1 /2 ; s > 0 ; s = s >>1 ) {
109- if (lid < s ) {
110- Add (lm [lid ], lm [lid ], lm [lid + s ]);
111- }
120+ #if defined(cl_khr_subgroups ) || defined(__opencl_c_subgroups )
121+ lm [get_sub_group_local_id ()] = sub_group_reduce_add (lm [lid ]);
112122 barrier (CLK_LOCAL_MEM_FENCE );
113- }
123+ for (int s = get_num_sub_groups () >> 1 ; s > 0 ; s >>= 1 ) {
124+ if (lid < s ) {
125+ Add (lm [lid ], lm [lid ], lm [lid + s ]);
126+ }
127+ }
128+ #else
129+ for (int s = WGS1 /2 ; s > 0 ; s = s >>1 ) {
130+ if (lid < s ) {
131+ Add (lm [lid ], lm [lid ], lm [lid + s ]);
132+ }
133+ barrier (CLK_LOCAL_MEM_FENCE );
134+ }
135+ #endif
114136
115137 if (lid == 0 ) {
116- dot [dot_offset ] = lm [0 ];
138+ dot [dot_offset ] = lm [lid ];
117139 }
118140 #endif
119141}
0 commit comments