Skip to content

Commit 48fb04b

Browse files
committed
metal: implement cross-entropy ops for MNIST
1 parent f7cb4b7 commit 48fb04b

File tree

7 files changed

+319
-94
lines changed

7 files changed

+319
-94
lines changed

src/ggml-metal/ggml-metal-device.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,6 +1676,36 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embeddi
16761676
return res;
16771677
}
16781678

1679+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss(ggml_metal_library_t lib, const ggml_tensor * op) {
1680+
char base[256];
1681+
char name[256];
1682+
1683+
snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(op->src[0]->type));
1684+
snprintf(name, 256, "%s", base);
1685+
1686+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1687+
if (!res.pipeline) {
1688+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1689+
}
1690+
1691+
return res;
1692+
}
1693+
1694+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const ggml_tensor * op) {
1695+
char base[256];
1696+
char name[256];
1697+
1698+
snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[1]->type));
1699+
snprintf(name, 256, "%s", base);
1700+
1701+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1702+
if (!res.pipeline) {
1703+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1704+
}
1705+
1706+
return res;
1707+
}
1708+
16791709
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
16801710
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
16811711

src/ggml-metal/ggml-metal-device.h

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -102,54 +102,56 @@ void ggml_metal_library_free(ggml_metal_library_t lib);
102102
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
103103
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
104104

105-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
106-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
107-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
108-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
109-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
110-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
111-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
112-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
113-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
114-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
115-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
116-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
117-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
118-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
119-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
120-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
121-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
122-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
123-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
124-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
125-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
126-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
127-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
128-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
129-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
130-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
131-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
132-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
133-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
134-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
135-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
136-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
137-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
138-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
139-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
140-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
141-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
142-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
143-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
144-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
145-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
146-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
147-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
148-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
149-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
150-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
151-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
152-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
105+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
106+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
107+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
108+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
109+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
110+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
111+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
112+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
113+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
114+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
115+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
116+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
117+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
118+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
119+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
120+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
121+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
122+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
123+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
124+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
125+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
126+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
127+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
128+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
129+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
130+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
131+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
132+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
133+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
134+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
135+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
136+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
137+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
138+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
139+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
140+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
141+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
142+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
143+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
144+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
145+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
146+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
147+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
148+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding (ggml_metal_library_t lib, const struct ggml_tensor * op);
149+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
150+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
151+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
152+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
153+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss (ggml_metal_library_t lib, const struct ggml_tensor * op);
154+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const struct ggml_tensor * op);
153155

154156
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
155157
ggml_metal_library_t lib,

src/ggml-metal/ggml-metal-device.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
10361036
op->type == GGML_TYPE_I64;
10371037
case GGML_OP_ARGMAX:
10381038
return has_simdgroup_reduction;
1039+
case GGML_OP_CROSS_ENTROPY_LOSS:
1040+
return has_simdgroup_reduction;
1041+
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
1042+
return has_simdgroup_reduction;
10391043
case GGML_OP_NORM:
10401044
case GGML_OP_RMS_NORM:
10411045
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));

src/ggml-metal/ggml-metal-impl.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,16 @@ typedef struct {
942942
uint64_t nb01;
943943
} ggml_metal_kargs_argmax;
944944

945+
typedef struct {
946+
int32_t nclasses;
947+
int32_t nrows;
948+
} ggml_metal_kargs_cross_entropy_loss;
949+
950+
typedef struct {
951+
int32_t nclasses;
952+
int32_t nrows;
953+
} ggml_metal_kargs_cross_entropy_loss_back;
954+
945955
typedef struct {
946956
int64_t np;
947957
} ggml_metal_kargs_opt_step_adamw;

0 commit comments

Comments
 (0)