Skip to content

Commit 61cb2c5

Browse files
committed
add opt-step-adamw kernel for metal
1 parent 101b8a3 commit 61cb2c5

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -773,4 +773,15 @@ typedef struct {
773773
uint64_t nb01;
774774
} ggml_metal_kargs_argmax;
775775

776+
// Arguments for the AdamW optimizer-step kernel (kernel_opt_step_adamw_f32).
// Keep the field order: the host side fills this struct positionally.
typedef struct {
    float   alpha;  // step size; scales both the moment-based update and the weight decay
    float   beta1;  // decay rate of the running first moment (g_m)
    float   beta2;  // decay rate of the running second moment (g_v)
    float   eps;    // added to the square root in the denominator of the update
    float   wd;     // weight decay, applied as x *= (1 - alpha*wd)
    float   beta1h; // multiplier applied to the updated first moment
    float   beta2h; // multiplier applied to the updated second moment before the sqrt
    int64_t np;     // number of elements to update; threads with index >= np do nothing
} ggml_metal_kargs_opt_step_adamw;
786+
776787
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3407,5 +3407,44 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
34073407
}
34083408

34093409
// Encodes one AdamW optimizer step for graph node `idx` on the Metal encoder.
//
// Bindings set up below:
//   index 0     - ggml_metal_kargs_opt_step_adamw (hyper-parameters + element count)
//   index 1..4  - buffers of op->src[0]..op->src[3]
// The seven float hyper-parameters are read from op->src[4] and copied into the
// argument struct. The kernel is launched with one thread per element of
// op->src[0], rounded up to whole threadgroups.
//
// Returns 1 (presumably the number of graph nodes consumed - confirm against the caller).
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    // the kernel treats src[0] as a flat array of np floats
    const int64_t np = ggml_nelements(op->src[0]);
    if (np <= 0) {
        // nothing to update; this also guarantees a non-zero threadgroup width below
        return 1;
    }

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);

    // NOTE(review): the hyper-parameters are read on the host while the command is being
    // encoded. This assumes src[4]->data is host-accessible and already holds the values
    // for this step - confirm, or bind src[4] as a kernel buffer instead.
    const float * params = (const float *) op->src[4]->data;

    ggml_metal_kargs_opt_step_adamw args = {
        /*.alpha  =*/ params[0],
        /*.beta1  =*/ params[1],
        /*.beta2  =*/ params[2],
        /*.eps    =*/ params[3],
        /*.wd     =*/ params[4],
        /*.beta1h =*/ params[5],
        /*.beta2h =*/ params[6],
        /*.np     =*/ np,
    };

    int ida = 0;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);

    // Size the threadgroup from the flat element count, not from the first dimension:
    // the kernel indexes elements linearly, so a tensor with a small ne[0] should still
    // get full-width threadgroups.
    const int     nth_max = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
    const int     nth     = (int) std::min<int64_t>(nth_max, np);
    const int64_t n       = (np + nth - 1) / nth; // ceil(np / nth) threadgroups

    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);

    return 1;
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8754,3 +8754,36 @@ kernel void kernel_pool_2d_avg_f32(
87548754

87558755
o_ptr[cur_oh * args.OW + cur_ow] = res;
87568756
}
8757+
8758+
// AdamW update for f32 data; each thread handles the element at its grid index.
//   args : hyper-parameters and the number of elements (np)
//   x    : values being optimized, updated in place
//   g    : gradients, read only
//   g_m  : running first moment, updated in place
//   g_v  : running second moment, updated in place
kernel void kernel_opt_step_adamw_f32(
        constant ggml_metal_kargs_opt_step_adamw & args,
        device       float * x,
        device const float * g,
        device       float * g_m,
        device       float * g_v,
        uint gid[[thread_position_in_grid]]) {
    // threads whose index falls past the last element have nothing to do
    if (gid < args.np) {
        const float grad = g[gid];

        // exponential moving averages of the gradient and of its square
        const float m = g_m[gid] * args.beta1 + grad * (1.0f - args.beta1);
        const float v = g_v[gid] * args.beta2 + grad * grad * (1.0f - args.beta2);

        g_m[gid] = m;
        g_v[gid] = v;

        // beta1h / beta2h rescale the moments (presumably precomputed bias-correction
        // factors - confirm against the code that fills the parameter tensor)
        const float m_hat = m * args.beta1h;
        const float denom = sqrt(v * args.beta2h) + args.eps;

        // decoupled weight decay on x, followed by the moment-based step
        x[gid] = x[gid] * (1.0f - args.alpha * args.wd) - args.alpha * m_hat / denom;
    }
}

0 commit comments

Comments
 (0)