Skip to content

Commit 101b8a3

Browse files
committed
scaffold to support opt step adamw on metal (not written so far)
1 parent a3cb047 commit 101b8a3

File tree

5 files changed

+30
-0
lines changed

5 files changed

+30
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,3 +1482,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
14821482
return res;
14831483
}
14841484

1485+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
1486+
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
1487+
1488+
char base[256];
1489+
char name[256];
1490+
1491+
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
1492+
snprintf(name, 256, "%s", base);
1493+
1494+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1495+
if (res) {
1496+
return res;
1497+
}
1498+
1499+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1500+
1501+
return res;
1502+
}

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
134134
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
135135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
136136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
137+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
137138

138139
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
139140
ggml_metal_library_t lib,

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
798798
return false;
799799
};
800800
}
801+
case GGML_OP_OPT_STEP_ADAMW:
802+
return true;
801803
default:
802804
return false;
803805
}

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
410410
{
411411
n_fuse = ggml_metal_op_argmax(ctx, idx);
412412
} break;
413+
case GGML_OP_OPT_STEP_ADAMW:
414+
{
415+
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
416+
} break;
413417
default:
414418
{
415419
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -3401,3 +3405,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
34013405

34023406
return 1;
34033407
}
3408+
3409+
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
3410+
return 1;
3411+
}

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
7878
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
7979
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
8080
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
81+
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
8182

8283
#ifdef __cplusplus
8384
}

0 commit comments

Comments
 (0)