Skip to content

Commit d7beb52

Browse files
committed
metal: implement cross-entropy ops for MNIST
1 parent ebc3a0f commit d7beb52

File tree

7 files changed

+372
-47
lines changed

7 files changed

+372
-47
lines changed

src/ggml-metal/ggml-metal-device.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,40 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embeddi
16511651
return res;
16521652
}
16531653

1654+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss(ggml_metal_library_t lib, const ggml_tensor * op) {
1655+
char base[256];
1656+
char name[256];
1657+
1658+
snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(op->src[0]->type));
1659+
snprintf(name, 256, "%s", base);
1660+
1661+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1662+
if (!res.pipeline) {
1663+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1664+
}
1665+
1666+
res.smem = 32 * sizeof(float);
1667+
1668+
return res;
1669+
}
1670+
1671+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const ggml_tensor * op) {
1672+
char base[256];
1673+
char name[256];
1674+
1675+
snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[1]->type));
1676+
snprintf(name, 256, "%s", base);
1677+
1678+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1679+
if (!res.pipeline) {
1680+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1681+
}
1682+
1683+
res.smem = 32 * sizeof(float);
1684+
1685+
return res;
1686+
}
1687+
16541688
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
16551689
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
16561690

src/ggml-metal/ggml-metal-device.h

Lines changed: 49 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -102,53 +102,55 @@ void ggml_metal_library_free(ggml_metal_library_t lib);
102102
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
103103
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
104104

105-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
106-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
107-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
108-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
109-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
110-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
111-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
112-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
113-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
114-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
115-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
116-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
117-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
118-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
119-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
121-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
122-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
123-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
124-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
125-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
126-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
127-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
128-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
129-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
130-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
131-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
132-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
133-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
134-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
135-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
136-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
137-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
138-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
139-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
140-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
141-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
142-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
143-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
144-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
145-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
146-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
147-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
148-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
149-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
150-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
151-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
105+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
106+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
107+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
108+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
109+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
110+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
111+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
112+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
113+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
114+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
115+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
116+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
117+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
118+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
119+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
121+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
122+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
123+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
124+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
125+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
126+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
127+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
128+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
129+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
130+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
131+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
132+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
133+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
134+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
135+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
136+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
137+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
138+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
139+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
140+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
141+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
142+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
143+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
144+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
145+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
146+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
147+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding (ggml_metal_library_t lib, const struct ggml_tensor * op);
148+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss (ggml_metal_library_t lib, const struct ggml_tensor * op);
149+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const struct ggml_tensor * op);
150+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
151+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
152+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
153+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
152154

153155
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
154156
ggml_metal_library_t lib,

src/ggml-metal/ggml-metal-device.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
10301030
op->type == GGML_TYPE_I64;
10311031
case GGML_OP_ARGMAX:
10321032
return has_simdgroup_reduction;
1033+
case GGML_OP_CROSS_ENTROPY_LOSS:
1034+
return has_simdgroup_reduction;
1035+
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
1036+
return has_simdgroup_reduction;
10331037
case GGML_OP_NORM:
10341038
case GGML_OP_RMS_NORM:
10351039
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));

src/ggml-metal/ggml-metal-impl.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,15 @@ typedef struct {
933933
uint64_t nb01;
934934
} ggml_metal_kargs_argmax;
935935

936+
typedef struct {
937+
int32_t n_classes;
938+
int32_t n_rows;
939+
} ggml_metal_kargs_cross_entropy_loss;
940+
941+
typedef struct {
942+
int32_t n_classes;
943+
} ggml_metal_kargs_cross_entropy_loss_back;
944+
936945
typedef struct {
937946
int64_t np;
938947
} ggml_metal_kargs_opt_step_adamw;

0 commit comments

Comments
 (0)