@@ -24,38 +24,19 @@ typedef struct {
2424 uchar qs [QK_MXFP4 /2 ];
2525} block_mxfp4 ;
2626
// Expands the four 4-bit codes packed in fp4x4 into four 16-bit patterns and
// returns them reinterpreted as a half4 (via as_half4). Nibble k of fp4x4
// (k = 0..3, lowest nibble first) ends up in component k of the result.
27- // single ushort contains 4 mxfp4 as input
28- static inline half4 mxfp4_to_fp16_packed (ushort fp4x4 ) {
29- ushort2 fp16_packed_a , fp16_packed_b , bias_a , bias_b , sign_a , sign_b ;
// Move the low 3 bits of each nibble into bits 9-11 (mask 0x0E00).
30- fp16_packed_a .lo = (fp4x4 << 9 ) & 0x0E00 ;
31- fp16_packed_a .hi = (fp4x4 << 5 ) & 0x0E00 ;
32- fp16_packed_b .lo = (fp4x4 << 1 ) & 0x0E00 ;
33- fp16_packed_b .hi = (fp4x4 >> 3 ) & 0x0E00 ;
34-
// Every nonzero 3-bit field gets a bias of 0x3800 (the fp16 bit pattern of 0.5);
// a zero field gets no bias, so it stays zero apart from the sign bit.
35- bias_a .lo = (fp16_packed_a .lo == 0 ) ? 0x0 : 0x3800 ;
36- bias_a .hi = (fp16_packed_a .hi == 0 ) ? 0x0 : 0x3800 ;
37- bias_b .lo = (fp16_packed_b .lo == 0 ) ? 0x0 : 0x3800 ;
38- bias_b .hi = (fp16_packed_b .hi == 0 ) ? 0x0 : 0x3800 ;
39-
// A field equal to 0x0200 is cleared, so that code contributes only the bias
// computed above (the bias was derived before this clearing step).
40- fp16_packed_a .lo = (fp16_packed_a .lo == 0x0200 ) ? 0x0 : fp16_packed_a .lo ;
41- fp16_packed_a .hi = (fp16_packed_a .hi == 0x0200 ) ? 0x0 : fp16_packed_a .hi ;
42- fp16_packed_b .lo = (fp16_packed_b .lo == 0x0200 ) ? 0x0 : fp16_packed_b .lo ;
43- fp16_packed_b .hi = (fp16_packed_b .hi == 0x0200 ) ? 0x0 : fp16_packed_b .hi ;
44-
// Move the top bit of each nibble into bit 15 (mask 0x8000), the sign position.
45- sign_a .lo = (fp4x4 << 12 ) & 0x8000 ;
46- sign_a .hi = (fp4x4 << 8 ) & 0x8000 ;
47- sign_b .lo = (fp4x4 << 4 ) & 0x8000 ;
48- sign_b .hi = fp4x4 & 0x8000 ;
49-
// Sign, bias and field occupy disjoint bits, so adding them assembles the pattern.
50- fp16_packed_a = sign_a + bias_a + fp16_packed_a ;
51- fp16_packed_b = sign_b + bias_b + fp16_packed_b ;
52-
53- return as_half4 ((ushort4 )(fp16_packed_a , fp16_packed_b ));
54- }
// 16-entry float lookup table in the constant address space. Entries 8..15 are
// the negated counterparts of entries 0..7. mul_mv_mxfp4_f32 copies it into
// local memory and indexes that copy with the low and high nibbles of each
// quantized byte.
// NOTE(review): `-0` is an integer expression equal to 0, so entry 8 is stored
// as +0.0f rather than -0.0f. Confirm the sign of zero is irrelevant here, or
// write `-0.f` if a negative zero is intended.
27+ constant static float kvalues_mxfp4_f [16 ] = {
28+ 0 , .5f , 1.f , 1.5f , 2.f , 3.f , 4.f , 6.f , -0 , -.5f , -1.f , -1.5f , -2.f , -3.f , -4.f , -6.f
29+ };
5530
// Builds a 32-bit pattern from the byte x and returns it reinterpreted as a
// float through as_float.
//   x != 0: x is shifted into bits 23-30, the fp32 exponent field, with sign
//           and mantissa left at zero.
//   x == 0: the fixed pattern 0x00400000 is used instead (exponent field zero,
//           top mantissa bit set, i.e. the subnormal encoding of 2^-127).
// The single-line conditional and the if/else form below compute the same value.
5631static inline float e8m0_to_fp32 (uchar x ) {
5732 int bits ;
58- bits = (x == 0 ) ? 0x00400000 : ((uint ) x << 23 );
33+
34+ if (x == 0 ) {
35+ bits = 0x00400000 ;
36+ } else {
37+ bits = (uint ) x << 23 ;
38+ }
39+
5940 return as_float (bits );
6041}
6142
@@ -84,8 +65,10 @@ inline void mul_mv_mxfp4_f32(
8465 int ne0 ,
8566 int ne1 ,
8667 int r2 ,
87- int r3
68+ int r3 ,
69+ local char * shmem
8870) {
71+ local float * shmem_f32 = (local float * ) shmem ;
8972 int nb = ne00 /QK_MXFP4 ;
9073
9174 int r0 = get_group_id (0 );
@@ -106,25 +89,31 @@ inline void mul_mv_mxfp4_f32(
10689 const short ix = get_sub_group_local_id ()/2 ; // 0...15
10790 const short it = get_sub_group_local_id ()%2 ; // 0 or 1
10891
92+ shmem_f32 [get_sub_group_local_id ()] = kvalues_mxfp4_f [get_sub_group_local_id ()%16 ];
93+ barrier (CLK_LOCAL_MEM_FENCE );
94+
95+ float4 yl [4 ];
10996 float sumf [N_R0_MXFP4 ] = {0.f };
11097
11198 global float * yb = y + ix * QK_MXFP4 + it * 8 ;
11299
113100 for (int ib = ix ; ib < nb ; ib += N_SIMDWIDTH /2 ) {
114101 global float4 * y4 = (global float4 * )yb ;
102+ yl [0 ] = y4 [0 ];
103+ yl [1 ] = y4 [4 ];
104+ yl [2 ] = y4 [1 ];
105+ yl [3 ] = y4 [5 ];
115106
116107 for (short row = 0 ; row < N_R0_MXFP4 ; row ++ ) {
117108 global block_mxfp4 * xb = x + row * nb + ib ;
118- global ushort * q2 = (global ushort * )(xb -> qs + 8 * it );
109+ global uchar * q2 = (global uchar * )(xb -> qs + 8 * it );
110+
111+ float4 acc1 = yl [0 ]* (float4 )(shmem_f32 [q2 [0 ] & 0x0F ], shmem_f32 [q2 [1 ] & 0x0F ], shmem_f32 [q2 [2 ] & 0x0F ], shmem_f32 [q2 [3 ] & 0x0F ]);
112+ float4 acc2 = yl [1 ]* (float4 )(shmem_f32 [q2 [0 ] >> 4 ], shmem_f32 [q2 [1 ] >> 4 ], shmem_f32 [q2 [2 ] >> 4 ], shmem_f32 [q2 [3 ] >> 4 ]);
113+ float4 acc3 = yl [2 ]* (float4 )(shmem_f32 [q2 [4 ] & 0x0F ], shmem_f32 [q2 [5 ] & 0x0F ], shmem_f32 [q2 [6 ] & 0x0F ], shmem_f32 [q2 [7 ] & 0x0F ]);
114+ float4 acc4 = yl [3 ]* (float4 )(shmem_f32 [q2 [4 ] >> 4 ], shmem_f32 [q2 [5 ] >> 4 ], shmem_f32 [q2 [6 ] >> 4 ], shmem_f32 [q2 [7 ] >> 4 ]);
119115
120- half4 fp16x4_0 = mxfp4_to_fp16_packed (q2 [0 ]);
121- half4 fp16x4_1 = mxfp4_to_fp16_packed (q2 [1 ]);
122- float4 acc1 = y4 [0 ]* (float4 )(fp16x4_0 .s0 , fp16x4_0 .s2 , fp16x4_1 .s0 , fp16x4_1 .s2 );
123- acc1 += y4 [4 ]* (float4 )(fp16x4_0 .s1 , fp16x4_0 .s3 , fp16x4_1 .s1 , fp16x4_1 .s3 );
124- fp16x4_0 = mxfp4_to_fp16_packed (q2 [2 ]);
125- fp16x4_1 = mxfp4_to_fp16_packed (q2 [3 ]);
126- acc1 += y4 [1 ]* (float4 )(fp16x4_0 .s0 , fp16x4_0 .s2 , fp16x4_1 .s0 , fp16x4_1 .s2 );
127- acc1 += y4 [5 ]* (float4 )(fp16x4_0 .s1 , fp16x4_0 .s3 , fp16x4_1 .s1 , fp16x4_1 .s3 );
116+ acc1 = (acc1 + acc3 ) + (acc2 + acc4 );
128117
129118 sumf [row ] += e8m0_to_fp32 (xb -> e ) * ((acc1 .s0 + acc1 .s1 ) + (acc1 .s2 + acc1 .s3 ));
130119 }
@@ -171,7 +160,8 @@ kernel void kernel_mul_mv_id_mxfp4_f32(
171160 int ne0 ,
172161 int ne1 ,
173162 int r2 ,
174- int r3
163+ int r3 ,
164+ local char * shmem
175165) {
176166 src0 = (global char * )((global char * )src0 + offset0 );
177167 src1 = (global char * )((global char * )src1 + offset1 );
@@ -195,5 +185,5 @@ kernel void kernel_mul_mv_id_mxfp4_f32(
195185 global char * dst_cur = dst + (i1 * ne0 + i2 * ne1 * ne0 )* sizeof (float );
196186
197187 mul_mv_mxfp4_f32 (src0_cur , src1_cur , dst_cur ,
198- ne00 , nb01 , nb02 , nb03 , ne12 , nb11 , nb12 , nb13 , ne0 , ne1 , r2 , r3 );
188+ ne00 , nb01 , nb02 , nb03 , ne12 , nb11 , nb12 , nb13 , ne0 , ne1 , r2 , r3 , shmem );
199189}
0 commit comments