Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -773,15 +773,4 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;

typedef struct {
float alpha;
float beta1;
float beta2;
float eps;
float wd;
float beta1h;
float beta2h;
int64_t np;
} ggml_metal_kargs_opt_step_adamw;

#endif // GGML_METAL_IMPL
16 changes: 2 additions & 14 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3419,28 +3419,16 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);

const int64_t np = ggml_nelements(op->src[0]);
const float * params = (const float *) op->src[4]->data;
ggml_metal_kargs_opt_step_adamw args = {
/*.alpha =*/ params[0],
/*.beta1 =*/ params[1],
/*.beta2 =*/ params[2],
/*.eps =*/ params[3],
/*.wd =*/ params[4],
/*.beta1h =*/ params[5],
/*.beta2h =*/ params[6],
/*.np =*/ np,
};

int ida = 0;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);

const int64_t np = ggml_nelements(op->src[0]);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;

Expand Down
20 changes: 8 additions & 12 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -8756,24 +8756,20 @@ kernel void kernel_pool_2d_avg_f32(
}

kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {

if (gid >= args.np) {
return;
}

const float alpha = args.alpha;
const float beta1 = args.beta1;
const float beta2 = args.beta2;
const float eps = args.eps;
const float wd = args.wd;
const float beta1h = args.beta1h;
const float beta2h = args.beta2h;
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];

const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
Expand Down