Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SUM);

char base[256];
char name[256];

snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));

Expand Down Expand Up @@ -1482,3 +1501,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_ADAMW);

char base[256];
char name[256];

snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

return res;
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand All @@ -134,6 +135,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
Expand Down Expand Up @@ -798,6 +799,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
};
}
case GGML_OP_OPT_STEP_ADAMW:
return has_simdgroup_reduction;
default:
return false;
}
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,10 @@ typedef struct{
float limit;
} ggml_metal_kargs_glu;

typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;

typedef struct {
int64_t ne00;
int64_t ne01;
Expand Down Expand Up @@ -773,4 +777,8 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;

typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;

#endif // GGML_METAL_IMPL
68 changes: 68 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_glu(ctx, idx);
} break;
case GGML_OP_SUM:
{
n_fuse = ggml_metal_op_sum(ctx, idx);
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
Expand Down Expand Up @@ -410,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argmax(ctx, idx);
} break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
Expand Down Expand Up @@ -840,6 +848,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);

ggml_metal_kargs_sum args = {
/*.np =*/ n,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);

return 1;
}

int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down Expand Up @@ -3401,3 +3433,39 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {

return 1;
}

int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);

const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = {
/*.np =*/ np,
};

int ida = 0;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);

const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;

ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);

return 1;
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
Expand Down Expand Up @@ -78,6 +79,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);

#ifdef __cplusplus
}
Expand Down
52 changes: 52 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
}
}

kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {

if (tiitg != 0) {
return;
}

float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
}

dst[0] = acc;
}

template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
Expand Down Expand Up @@ -8754,3 +8772,37 @@ kernel void kernel_pool_2d_avg_f32(

o_ptr[cur_oh * args.OW + cur_ow] = res;
}

kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {

if (gid >= args.np) {
return;
}

const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];

const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);

g_m[gid] = gmi;
g_v[gid] = gvi;

const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;

x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
Loading