diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index e61a5706d610c..c5591915128d9 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -383,6 +383,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) { // fuse only ops that start with these operations // can be expanded when needed if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_NORM || node.op() == GGML_OP_RMS_NORM) { ops[0] = node.op(); @@ -392,6 +393,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) { // can be expanded when needed if (gf->nodes[f]->op != GGML_OP_ADD && gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_NORM && gf->nodes[f]->op != GGML_OP_RMS_NORM) { break; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 9f91662cbd876..daaf2be515a00 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { - assert(op->op == GGML_OP_RMS_NORM); - - GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); - - char base[256]; - char name[256]; - - switch (n_fuse) { - case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break; - case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break; - case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break; - default: GGML_ABORT("fatal error"); - } - - snprintf(name, 256, "%s", base); - - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; - } - - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); - - return res; -} - ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); @@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) { - assert(op->op == GGML_OP_NORM); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { + assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); - GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - snprintf(base, 256, "kernel_norm_f32"); + const char * suffix = ""; + if (op->ne[0] % 4 == 0) { + suffix = "_4"; + } + + switch (op->op) { + case GGML_OP_NORM: + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break; + case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break; + case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break; + default: GGML_ABORT("fatal error"); + } break; + case GGML_OP_RMS_NORM: + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break; + case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break; + case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break; + default: GGML_ABORT("fatal error"); + } break; + default: GGML_ABORT("fatal error"); + } + snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index da67bfab758e8..dda7eca85a7e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -123,10 +123,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 67f71ace2b444..5f744d1a0bd96 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); - case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: return has_simdgroup_reduction; case GGML_OP_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_RMS_NORM: + return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0])); case GGML_OP_ROPE: return true; case GGML_OP_IM2COL: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 3a7a4f317c638..ab51b76c8cd36 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -428,16 +428,11 @@ typedef struct { uint64_t nb1; } ggml_metal_kargs_mul_mv_id; +// NORM +// RMS_NORM typedef struct { int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_norm; - -typedef struct { - int32_t ne00; - int32_t ne00_4; + int32_t ne00_t; uint64_t nb1; uint64_t nb2; uint64_t nb3; @@ -448,7 +443,7 @@ typedef struct { uint64_t nbf1[3]; uint64_t nbf2[3]; uint64_t nbf3[3]; -} ggml_metal_kargs_rms_norm; +} ggml_metal_kargs_norm; typedef struct { int32_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3b163d9a38e75..8059af951ae94 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -266,10 +266,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_set_rows(ctx, idx); } break; - case GGML_OP_RMS_NORM: - { - n_fuse = ggml_metal_op_rms_norm(ctx, idx); - } break; case GGML_OP_L2_NORM: { n_fuse = ggml_metal_op_l2_norm(ctx, idx); @@ -279,6 +275,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_group_norm(ctx, idx); } break; case GGML_OP_NORM: + case GGML_OP_RMS_NORM: { n_fuse = ggml_metal_op_norm(ctx, idx); } break; @@ -2346,146 +2343,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { return n_fuse; } -int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - const int idx_end = ctx->idx_end; - - const bool use_fusion = ctx->use_fusion; - - const int debug_fusion = ctx->debug_fusion; - - ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); - - float eps; - memcpy(&eps, op->op_params, sizeof(float)); - - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); - ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - - ggml_metal_kargs_rms_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.eps =*/ eps, - /*.nef1 =*/ { ne01 }, - /*.nef2 =*/ { ne02 }, - /*.nef3 =*/ { ne03 }, - /*.nbf1 =*/ { nb01 }, - /*.nbf2 =*/ { nb02 }, - /*.nbf3 =*/ { nb03 }, - }; - - ggml_op fops[8]; - - int n_fuse = 1; - - ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; - - // d[0] = rms_norm(a) - // d[1] = mul(d[0], b) - // d[2] = add(d[1], c) - if (use_fusion) { - fops[0] = GGML_OP_RMS_NORM; - fops[1] = GGML_OP_MUL; - fops[2] = GGML_OP_ADD; - - for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { - break; - } - - if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { - break; - } - - if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) { - break; - } - - if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) { - break; - } - - if (ops[n_fuse + 1]->type != GGML_TYPE_F32) { - break; - } - - //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; - - bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); - - args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1]; - args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2]; - args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3]; - - args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1]; - args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2]; - args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3]; - } - - ++n_fuse; - - if (debug_fusion > 1 && n_fuse > 1) { - if (n_fuse == 2) { - GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); - } - if (n_fuse == 3) { - GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); - } - } - } - - if (n_fuse > 1) { - bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); - - for (int i = 1; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { - ggml_metal_op_concurrency_reset(ctx); - - break; - } - } - } - - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; - } - - nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00/4); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, bid_src0, 1); - ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); - ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); - ggml_metal_encoder_set_buffer (enc, bid_dst, 4); - - ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); - - return n_fuse; -} - int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { ggml_cgraph * gf = ctx->gf; ggml_tensor * op = ggml_graph_node(gf, idx); @@ -2594,6 +2451,14 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; + const int idx_end = ctx->idx_end; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); GGML_TENSOR_LOCALS( int32_t, ne, op, ne); @@ -2602,37 +2467,121 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { float eps; memcpy(&eps, op->op_params, sizeof(float)); + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_norm args = { /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, + /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op); + ggml_op fops[8]; + + int n_fuse = 1; + + ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; + + // d[0] = norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (use_fusion) { + fops[0] = op->op; + fops[1] = GGML_OP_MUL; + fops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + break; + } + + if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) { + break; + } + + if (ops[n_fuse + 1]->type != GGML_TYPE_F32) { + break; + } + + //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + + args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1]; + args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2]; + args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3]; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op)); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op)); + } + } + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); int nth = 32; // SIMD width - while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + + while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00/4); + nth = std::min(nth, args.ne00_t); const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - const int64_t nrows = ggml_nrows(op->src[0]); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); + ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); + ggml_metal_encoder_set_buffer (enc, bid_dst, 4); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); - return 1; + return n_fuse; } int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index b620de164d755..a1151f881b372 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -60,7 +60,6 @@ int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2ba4cb50b9ec2..11920a3d19b9b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -66,6 +66,10 @@ static inline float e8m0_to_fp32(uint8_t x) { return as_type(bits); } +static inline float dot(float x, float y) { + return x*y; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -2493,30 +2497,43 @@ kernel void kernel_argmax_f32( dst_i32[tgpig] = arg_val; } -kernel void kernel_norm_f32( +// F == 1 : norm (no fuse) +// F == 2 : norm + mul +// F == 3 : norm + mul + add +template +kernel void kernel_norm_fuse_impl( constant ggml_metal_kargs_norm & args, device const char * src0, + device const char * src1_0, + device const char * src1_1, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; - float4 sumf4(0.0f); + device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + + device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); + + T sumft(0.0f); float sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - sumf4 += x[i00]; + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { + sumft += x[i00]; } - sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = dot(sumft, T(1.0f)); sumf = simd_sum(sumf); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2532,10 +2549,10 @@ kernel void kernel_norm_f32( const float mean = sumf/args.ne00; - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { y[i00] = x[i00] - mean; sumf += dot(y[i00], y[i00]); } @@ -2555,17 +2572,35 @@ kernel void kernel_norm_f32( const float variance = sumf/args.ne00; const float scale = 1.0f/sqrt(variance + args.eps); - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - y[i00] = y[i00] * scale; + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { + if (F == 1) { + y[i00] = (y[i00]*scale); + } + if (F == 2) { + y[i00] = (y[i00]*scale)*f0[i00]; + } + if (F == 3) { + y[i00] = (y[i00]*scale)*f0[i00] + f1[i00]; + } } } +typedef decltype(kernel_norm_fuse_impl) kernel_norm_fuse_t; + +template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; + +template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; + // F == 1 : rms_norm (no fuse) // F == 2 : rms_norm + mul // F == 3 : rms_norm + mul + add -template +template kernel void kernel_rms_norm_fuse_impl( - constant ggml_metal_kargs_rms_norm & args, + constant ggml_metal_kargs_norm & args, device const char * src0, device const char * src1_0, device const char * src1_1, @@ -2584,15 +2619,15 @@ kernel void kernel_rms_norm_fuse_impl( const int i02 = tgpig.y; const int i03 = tgpig.z; - device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); - device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); - device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); + device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -2611,8 +2646,8 @@ kernel void kernel_rms_norm_fuse_impl( const float mean = sumf/args.ne00; const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); - for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { if (F == 1) { y[i00] = (x[i00]*scale); } @@ -2625,11 +2660,15 @@ kernel void kernel_rms_norm_fuse_impl( } } -typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t; +typedef decltype(kernel_rms_norm_fuse_impl) kernel_rms_norm_fuse_t; + +template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; -template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; -template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; -template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; +template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; kernel void kernel_l2_norm_f32( constant ggml_metal_kargs_l2_norm & args, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 592631f3ed21a..fdc90ca0697dc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6117,7 +6117,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { - test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));