@@ -1194,70 +1194,6 @@ kernel void kernel_neg(
11941194 dst[tpig] = -src0[tpig];
11951195}
11961196
// Gated linear unit with a rectifier gate.
//
// One threadgroup processes one row, selected by tgpig; the threads of the
// group walk the row's args.ne0 elements starting at tpitg in steps of ntg.
//
//   src0, src1 : byte pointers to the two f32 inputs; rows are args.nb01 and
//                args.nb11 bytes apart, and each row is additionally shifted
//                by args.i00 / args.i10 float elements
//   dst        : byte pointer to the f32 output; rows are args.nb1 bytes apart
//
// Each output element is x0*x1 multiplied by 1 when x0 > 0 and by 0 otherwise.
kernel void kernel_reglu(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // row base addresses are computed in bytes first, then viewed as float
    device const char * gate_bytes = src0 + tgpig*args.nb01;
    device const char * up_bytes   = src1 + tgpig*args.nb11;
    device       char * out_bytes  = dst  + tgpig*args.nb1;

    device const float * gate_row = (device const float *) gate_bytes + args.i00;
    device const float * up_row   = (device const float *) up_bytes   + args.i10;
    device       float * out_row  = (device       float *) out_bytes;

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float g = gate_row[col];
        const float u = up_row[col];

        // kept as a multiply by 0/1 (not a select) so NaN/Inf propagate as before
        const float mask = (g > 0.0f);

        out_row[col] = g*u*mask;
    }
}
1216-
// Gated linear unit with a tanh-based GELU gate.
//
// One threadgroup processes one row, selected by tgpig; the threads of the
// group walk the row's args.ne0 elements starting at tpitg in steps of ntg.
//
//   src0, src1 : byte pointers to the two f32 inputs; rows are args.nb01 and
//                args.nb11 bytes apart, and each row is additionally shifted
//                by args.i00 / args.i10 float elements
//   dst        : byte pointer to the f32 output; rows are args.nb1 bytes apart
//
// Each output element is
//   0.5*x0*(1 + tanh(SQRT_2_OVER_PI*x0*(1 + GELU_COEF_A*x0*x0))) * x1
kernel void kernel_geglu(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // row base addresses are computed in bytes first, then viewed as float
    device const char * gate_bytes = src0 + tgpig*args.nb01;
    device const char * up_bytes   = src1 + tgpig*args.nb11;
    device       char * out_bytes  = dst  + tgpig*args.nb1;

    device const float * gate_row = (device const float *) gate_bytes + args.i00;
    device const float * up_row   = (device const float *) up_bytes   + args.i10;
    device       float * out_row  = (device       float *) out_bytes;

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float g = gate_row[col];
        const float u = up_row[col];

        // argument of tanh, built up in the same order as the one-line form
        const float poly  = 1.0f + GELU_COEF_A*g*g;
        const float inner = SQRT_2_OVER_PI*g*poly;

        const float act = 0.5f*g*(1.0f + precise::tanh(inner));

        out_row[col] = act*u;
    }
}
1238-
// Gated linear unit with a SiLU gate.
//
// One threadgroup processes one row, selected by tgpig; the threads of the
// group walk the row's args.ne0 elements starting at tpitg in steps of ntg.
//
//   src0, src1 : byte pointers to the two f32 inputs; rows are args.nb01 and
//                args.nb11 bytes apart, and each row is additionally shifted
//                by args.i00 / args.i10 float elements
//   dst        : byte pointer to the f32 output; rows are args.nb1 bytes apart
//
// Each output element is (x0 / (1 + exp(-x0))) * x1.
kernel void kernel_swiglu(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // row base addresses are computed in bytes first, then viewed as float
    device const char * gate_bytes = src0 + tgpig*args.nb01;
    device const char * up_bytes   = src1 + tgpig*args.nb11;
    device       char * out_bytes  = dst  + tgpig*args.nb1;

    device const float * gate_row = (device const float *) gate_bytes + args.i00;
    device const float * up_row   = (device const float *) up_bytes   + args.i10;
    device       float * out_row  = (device       float *) out_bytes;

    for (int col = tpitg; col < args.ne0; col += ntg) {
        const float g = gate_row[col];
        const float u = up_row[col];

        const float denom = 1.0f + exp(-g);
        const float act   = g / denom;

        out_row[col] = act*u;
    }
}
1260-
12611197template <bool norm>
12621198kernel void kernel_sum_rows (
12631199 constant ggml_metal_kargs_sum_rows & args,
@@ -1298,14 +1234,7 @@ kernel void kernel_sum_rows(
12981234 shmem_f32[sgitg] = sumf;
12991235 }
13001236
1301- threadgroup_barrier (mem_flags::mem_threadgroup);
1302-
1303- sumf = shmem_f32[tiisg];
1304- sumf = simd_sum (sumf);
1305-
1306- if (tpitg.x == 0 ) {
1307- dst_row[0 ] = norm ? sumf / args.ne00 : sumf;
1308- }
1237+ dst_row[0 ] = row_sum;
13091238}
13101239
13111240typedef decltype (kernel_sum_rows<false >) kernel_sum_rows_t;
@@ -4807,10 +4736,51 @@ kernel void kernel_cpy_f32_q5_1(
48074736 for (int64_t i00 = tpitg.x *QK5_1; i00 < args.ne00 ; i00 += ntg.x *QK5_1) {
48084737 device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00 );
48094738
4810- quantize_q5_1 (src, dst_data[i00/QK5_1]);
4739+ float max = src[0 ];
4740+ float min = src[0 ];
4741+
4742+ for (int j = 1 ; j < QK5_1; j++) {
4743+ const float v = src[j];
4744+ min = v < min ? v : min;
4745+ max = v > max ? v : max;
4746+ }
4747+
4748+ const float d = (max - min) / 31 ;
4749+ const float id = d ? 1 .0f /d : 0 .0f ;
4750+
4751+ dst_data[i00/QK5_1].d = d;
4752+ dst_data[i00/QK5_1].m = min;
4753+
4754+ uint32_t qh = 0 ;
4755+ for (int j = 0 ; j < QK5_1/2 ; ++j) {
4756+ const float x0 = (src[0 + j] - min)*id;
4757+ const float x1 = (src[QK5_1/2 + j] - min)*id;
4758+
4759+ const uint8_t xi0 = (uint8_t )(x0 + 0 .5f );
4760+ const uint8_t xi1 = (uint8_t )(x1 + 0 .5f );
4761+
4762+ dst_data[i00/QK5_1].qs [j] = (xi0 & 0xf ) | ((xi1 & 0xf ) << 4 );
4763+ qh |= ((xi0 & 0x10u ) >> 4 ) << (j + 0 );
4764+ qh |= ((xi1 & 0x10u ) >> 4 ) << (j + QK5_1/2 );
4765+ }
4766+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4767+ for (int j = 0 ; j < 4 ; ++j) {
4768+ dst_data[i00/QK5_1].qh [j] = qh8[j];
4769+ }
48114770 }
48124771}
48134772
// Returns the index of the entry in val[0..n-1] that is closest to x.
//
// Values at or below val[0] map to 0 and values at or above val[n-1] map to
// n-1. Otherwise the table is bisected until two neighbouring entries bracket
// x, and the nearer of the two is chosen (the upper one on a tie).
//
// NOTE(review): the bisection presumes val is sorted in ascending order and
// that n >= 2 — confirm against the tables passed by callers.
static inline int best_index_int8(int n, constant float * val, float x) {
    const int last = n - 1;

    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[last]) {
        return last;
    }

    int lo = 0;
    int hi = last;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // x now lies between val[hi-1] and val[hi]; pick whichever is nearer
    const float below = x - val[hi - 1];
    const float above = val[hi] - x;
    if (below < above) {
        return hi - 1;
    }
    return hi;
}
4783+
48144784kernel void kernel_cpy_f32_iq4_nl (
48154785 constant ggml_metal_kargs_cpy & args,
48164786 device const char * src0,
0 commit comments