Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
// fuse only ops that start with these operations
// can be expanded when needed
if (node.op() == GGML_OP_ADD ||
node.op() == GGML_OP_NORM ||
node.op() == GGML_OP_RMS_NORM) {
ops[0] = node.op();

Expand All @@ -392,6 +393,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
// can be expanded when needed
if (gf->nodes[f]->op != GGML_OP_ADD &&
gf->nodes[f]->op != GGML_OP_MUL &&
gf->nodes[f]->op != GGML_OP_NORM &&
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
break;
}
Expand Down
61 changes: 26 additions & 35 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
assert(op->op == GGML_OP_RMS_NORM);

GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

char base[256];
char name[256];

switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
default: GGML_ABORT("fatal error");
}

snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

ggml_metal_pipeline_set_smem(res, 32*sizeof(float));

return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);

Expand Down Expand Up @@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_NORM);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);

GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

char base[256];
char name[256];

snprintf(base, 256, "kernel_norm_f32");
const char * suffix = "";
if (op->ne[0] % 4 == 0) {
suffix = "_4";
}

switch (op->op) {
case GGML_OP_NORM:
switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
default: GGML_ABORT("fatal error");
} break;
case GGML_OP_RMS_NORM:
switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
}

snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
Expand Down
3 changes: 1 addition & 2 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:
Expand Down
13 changes: 4 additions & 9 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,11 @@ typedef struct {
uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;

// NORM
// RMS_NORM
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_norm;

typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne00_t;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
Expand All @@ -448,7 +443,7 @@ typedef struct {
uint64_t nbf1[3];
uint64_t nbf2[3];
uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;
} ggml_metal_kargs_norm;

typedef struct {
int32_t ne00;
Expand Down
Loading