@@ -24,15 +24,6 @@ typedef struct {
2424 uchar qs [QK_MXFP4 /2 ];
2525} block_mxfp4 ;
2626
27-
28- // static inline half mxfp4_to_fp16(uchar fp4) {
29- // ushort sign = (fp4 >> 3) & 0x1;
30- // ushort d = (fp4 >> 2) & 0x1;
31- // ushort a = fp4 & 0x7;
32-
33- // return (1 - sign * 2) * ((1-d) * a * 0.5f + d * (ushort)(1.2999f * a - 3.0799f));
34- // }
35-
3627// single ushort contains 4 mxfp4 as input
3728static inline half4 mxfp4_to_fp16_packed (ushort fp4x4 ) {
3829 ushort2 fp16_packed_a , fp16_packed_b , bias_a , bias_b , sign_a , sign_b ;
@@ -93,10 +84,8 @@ inline void mul_mv_mxfp4_f32(
9384 int ne0 ,
9485 int ne1 ,
9586 int r2 ,
96- int r3 ,
97- local char * shmem
87+ int r3
9888) {
99- // local float * shmem_f32 = (local float *) shmem;
10089 int nb = ne00 /QK_MXFP4 ;
10190
10291 int r0 = get_group_id (0 );
@@ -117,35 +106,25 @@ inline void mul_mv_mxfp4_f32(
117106 const short ix = get_sub_group_local_id ()/2 ; // 0...15
118107 const short it = get_sub_group_local_id ()%2 ; // 0 or 1
119108
120- float4 yl [4 ];
121109 float sumf [N_R0_MXFP4 ] = {0.f };
122110
123111 global float * yb = y + ix * QK_MXFP4 + it * 8 ;
124112
125113 for (int ib = ix ; ib < nb ; ib += N_SIMDWIDTH /2 ) {
126114 global float4 * y4 = (global float4 * )yb ;
127- yl [0 ] = y4 [0 ];
128- yl [1 ] = y4 [4 ];
129- yl [2 ] = y4 [1 ];
130- yl [3 ] = y4 [5 ];
131115
132116 for (short row = 0 ; row < N_R0_MXFP4 ; row ++ ) {
133117 global block_mxfp4 * xb = x + row * nb + ib ;
134118 global ushort * q2 = (global ushort * )(xb -> qs + 8 * it );
135119
136120 half4 fp16x4_0 = mxfp4_to_fp16_packed (q2 [0 ]);
137121 half4 fp16x4_1 = mxfp4_to_fp16_packed (q2 [1 ]);
138- float4 acc1 = yl [0 ]* (float4 )(fp16x4_0 .s0 , fp16x4_0 .s2 , fp16x4_1 .s0 , fp16x4_1 .s2 );
139- acc1 += yl [ 1 ]* (float4 )(fp16x4_0 .s1 , fp16x4_0 .s3 , fp16x4_1 .s1 , fp16x4_1 .s3 );
122+ float4 acc1 = y4 [0 ]* (float4 )(fp16x4_0 .s0 , fp16x4_0 .s2 , fp16x4_1 .s0 , fp16x4_1 .s2 );
123+ acc1 += y4 [ 4 ]* (float4 )(fp16x4_0 .s1 , fp16x4_0 .s3 , fp16x4_1 .s1 , fp16x4_1 .s3 );
140124 fp16x4_0 = mxfp4_to_fp16_packed (q2 [2 ]);
141125 fp16x4_1 = mxfp4_to_fp16_packed (q2 [3 ]);
142- acc1 += yl [2 ]* (float4 )(fp16x4_0 .s0 , fp16x4_0 .s2 , fp16x4_1 .s0 , fp16x4_1 .s2 );
143- acc1 += yl [3 ]* (float4 )(fp16x4_0 .s1 , fp16x4_0 .s3 , fp16x4_1 .s1 , fp16x4_1 .s3 );
144-
145- // float4 acc1 = yl[0]*(float4)(mxfp4_to_fp16(q2[0] & 0x0F), mxfp4_to_fp16(q2[1] & 0x0F), mxfp4_to_fp16(q2[2] & 0x0F), mxfp4_to_fp16(q2[3] & 0x0F));
146- // acc1 += yl[1]*(float4)(mxfp4_to_fp16(q2[0] >> 4 ), mxfp4_to_fp16(q2[1] >> 4 ), mxfp4_to_fp16(q2[2] >> 4 ), mxfp4_to_fp16(q2[3] >> 4 ));
147- // acc1 += yl[2]*(float4)(mxfp4_to_fp16(q2[4] & 0x0F), mxfp4_to_fp16(q2[5] & 0x0F), mxfp4_to_fp16(q2[6] & 0x0F), mxfp4_to_fp16(q2[7] & 0x0F));
148- // acc1 += yl[3]*(float4)(mxfp4_to_fp16(q2[4] >> 4 ), mxfp4_to_fp16(q2[5] >> 4 ), mxfp4_to_fp16(q2[6] >> 4 ), mxfp4_to_fp16(q2[7] >> 4 ));
126+ acc1 += y4 [1 ]* (float4 )(fp16x4_0 .s0 , fp16x4_0 .s2 , fp16x4_1 .s0 , fp16x4_1 .s2 );
127+ acc1 += y4 [5 ]* (float4 )(fp16x4_0 .s1 , fp16x4_0 .s3 , fp16x4_1 .s1 , fp16x4_1 .s3 );
149128
150129 sumf [row ] += e8m0_to_fp32 (xb -> e ) * ((acc1 .s0 + acc1 .s1 ) + (acc1 .s2 + acc1 .s3 ));
151130 }
@@ -192,8 +171,7 @@ kernel void kernel_mul_mv_id_mxfp4_f32(
192171 int ne0 ,
193172 int ne1 ,
194173 int r2 ,
195- int r3 ,
196- local char * shmem
174+ int r3
197175) {
198176 src0 = (global char * )((global char * )src0 + offset0 );
199177 src1 = (global char * )((global char * )src1 + offset1 );
@@ -214,14 +192,8 @@ kernel void kernel_mul_mv_id_mxfp4_f32(
214192 global char * src0_cur = src0 + i02 * nb02 ;
215193 global char * src1_cur = src1 + i11 * nb11 + i12 * nb12 ;
216194
217- // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
218- // printf("[kernel_mul_mv_id_mxfp4_f32_flat] src1(%lu): %f, src2(%lu): %d\n", offset1, ((global float*)src1)[0], offset2, ((global int*)src2)[0]);
219- // global block_mxfp4 * block = (global block_mxfp4 *)(src0);
220- // printf("[kernel_mul_mv_id_mxfp4_f32] i02: %d, offset0: %d, e: %d, q[0]: %d, q[16]: %d\n", i02, offset0, block->e, block->qs[0], block->qs[15]);
221- // }
222-
223195 global char * dst_cur = dst + (i1 * ne0 + i2 * ne1 * ne0 )* sizeof (float );
224196
225197 mul_mv_mxfp4_f32 (src0_cur , src1_cur , dst_cur ,
226- ne00 , nb01 , nb02 , nb03 , ne12 , nb11 , nb12 , nb13 , ne0 , ne1 , r2 , r3 , shmem );
198+ ne00 , nb01 , nb02 , nb03 , ne12 , nb11 , nb12 , nb13 , ne0 , ne1 , r2 , r3 );
227199}
0 commit comments