@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
134134 src1 = src1 + offset1 ;
135135 dst = dst + offsetd ;
136136
137+ // The size of sum is sizeof(float)*subgroup_size.
138+ // Each subgroup writes its partial sum to this array.
139+ // So the number of subgroups per workgroup for this kernel must not exceed the subgroup size.
140+ // This generally holds in practice -
141+ // for subgroup size 64, the workgroup size can be at most 64*64 = 4096 (the max is usually 1024).
142+ if (get_sub_group_id () == 0 ) {
143+ sum [get_sub_group_local_id ()] = 0.0f ;
144+ }
145+
137146 int i03 = get_group_id (2 );
138147 int i02 = get_group_id (1 );
139148 int i01 = get_group_id (0 );
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
148157 sumf += dot (x [i00 ], x [i00 ]);
149158 }
150159 sumf = sub_group_reduce_add (sumf );
160+
161+ barrier (CLK_LOCAL_MEM_FENCE );
162+
151163 if (get_sub_group_local_id () == 0 ) {
152164 sum [get_sub_group_id ()] = sumf ;
153165 }
154166
155167 barrier (CLK_LOCAL_MEM_FENCE );
156168
157- for (uint i = get_local_size (0 ) / get_max_sub_group_size () / 2 ; i > 0 ; i /= 2 ) {
158- if (get_local_id (0 ) < i ) {
159- sum [get_local_id (0 )] += sum [get_local_id (0 ) + i ];
160- }
161- }
162- if (get_local_id (0 ) == 0 ) {
163- sum [0 ] /= ne00 ;
164- }
169+ // for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
170+ // if (get_local_id(0) < i) {
171+ // sum[get_local_id(0)] += sum[get_local_id(0) + i];
172+ // }
173+ // }
174+ // if (get_local_id(0) == 0) {
175+ // sum[0] /= ne00;
176+ // }
165177
166- barrier (CLK_LOCAL_MEM_FENCE );
178+ //barrier(CLK_LOCAL_MEM_FENCE);
179+
180+ sumf = sum [get_sub_group_local_id ()];
181+ sumf = sub_group_reduce_add (sumf );
167182
168- float mean = sum [ 0 ] ;
183+ float mean = sumf / ne00 ;
169184 float scale = 1.0f /sqrt (mean + eps );
170185
171186 global float4 * y = (global float4 * ) (dst + i03 * nb3 + i02 * nb2 + i01 * nb1 );
0 commit comments