@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
66#endif // RTE16
77
88#include "types.comp"
9- #include "generic_unary_head.comp"
109
11- #if defined(DATA_A_IQ4_NL)
12- // 16 invocations needed for init_iq4nl_shmem
13- layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in ;
10+ #if defined(SET_ROWS) && QUANT_K == 1 // NOTE(review): QUANT_K == 1 presumably selects the non-block destination types (F32/F16/BF16) - confirm
11+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
12+ const uint BLOCK_SIZE = 512 ; // mirrors local_size_x above; the SET_ROWS main() scales its flattened workgroup index by this
1413#else
15- layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
14+ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
15+ const uint BLOCK_SIZE = 32; // mirrors local_size_x for every other configuration
1616#endif
1717
1818layout (binding = 0) readonly buffer S {float data_s[];};
19+
20+ #if defined(SET_ROWS) // SET_ROWS builds take a second input buffer, so they use the binary head and three bindings
21+ #include "generic_binary_head.comp"
22+ layout (binding = 1) readonly buffer C {uvec2 data_i[];}; // index buffer; main() reads only the .x component and passes it as the second argument of dst_idx()
23+ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; // output moves to binding 2 because binding 1 holds data_i
24+ #else // plain copy/quantize: one input (binding 0) and one output (binding 1)
25+ #include "generic_unary_head.comp"
1926layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
27+ #endif
2028
2129#if defined(DATA_A_Q4_0)
2230void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
221229}
222230#endif
223231
232+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
233+ void quantize(uint dst_idx, uint src_idx) // F32/F16 variant: copies a single element, converting it with the A_TYPE constructor
234+ {
235+ data_q[dst_idx] = A_TYPE(data_s[src_idx]); // data_s is float; dst_idx/src_idx are used as direct element indices into data_q/data_s
236+ }
237+ #endif
238+
239+ #if defined(DATA_A_BF16)
240+ void quantize(uint dst_idx, uint src_idx) // BF16 variant: copies a single element through fp32_to_bf16()
241+ {
242+ data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx])); // fp32_to_bf16 is defined outside this view; its result is wrapped in A_TYPE before the store
243+ }
244+ #endif
245+
246+ #if defined(SET_ROWS)
247+
224248void main() { // SET_ROWS variant: one quantize() call per invocation, with the destination's second index looked up in data_i
225249#ifdef NEEDS_INIT_IQ_SHMEM
226250 init_iq_shmem(gl_WorkGroupSize);
227- if (gl_LocalInvocationIndex.x != 0) {
251+ #endif
252+ // Flatten the workgroup ID as z*262144 + y*512 + x (262144 == 512*512), scale by BLOCK_SIZE, add the local x ID, then scale by QUANT_K.
253+ const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
254+
255+ if (idx >= p.ne) { // invocations whose index is at or beyond p.ne do nothing
228256 return;
229257 }
258+ // get_indices() comes from the included head; presumably it splits idx into per-dimension source indices - confirm.
259+ uint i00, i01, i02, i03;
260+ get_indices(idx, i00, i01, i02, i03);
261+ // Coordinates used to address data_i; fastmod() is defined outside this view (presumably a modulo - confirm).
262+ uint i12 = fastmod(i03, p.ne12);
263+ uint i11 = fastmod(i02, p.ne11);
264+ uint i10 = i01;
265+ // Only the .x component of the uvec2 entry is used.
266+ uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
267+ // The locals below share their names with the helper functions called in their own initializers.
268+ uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
269+ uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); // i1 replaces i01 as the second destination index; i00 is divided by QUANT_K
270+
271+ quantize(dst_idx, src0_idx);
272+ }
273+
274+ #else
275+
276+ void main() { // non-SET_ROWS variant; part of this body is hidden by the hunk boundary below, so only the visible rows are annotated
277+ #ifdef NEEDS_INIT_IQ_SHMEM
278+ init_iq_shmem(gl_WorkGroupSize);
230279#endif
231280 // Only the x workgroup index is scaled by 32 here; y and z keep strides 512 and 262144. The sum plus the local x ID is then scaled by QUANT_K.
232- const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
281+ const uint idx = ( gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K; // NOTE(review): 32 equals BLOCK_SIZE whenever SET_ROWS is undefined
233282

234283 if (idx >= p.ne) { // invocations whose index is at or beyond p.ne do nothing
235284 return;
@@ -240,3 +289,5 @@ void main() {
240289

241290 quantize(dst_idx, src_idx);
242291}
292+
293+ #endif
0 commit comments