Skip to content

Commit 1a13964

Browse files
authored
metal : add cumsum (ggml-org#17305)
1 parent 2376b77 commit 1a13964

File tree

9 files changed

+406
-79
lines changed

9 files changed

+406
-79
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
318318
return res;
319319
}
320320

321+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
322+
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
323+
324+
char base[256];
325+
char name[256];
326+
327+
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
328+
snprintf(name, 256, "%s", base);
329+
330+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
331+
if (res) {
332+
return res;
333+
}
334+
335+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
336+
337+
return res;
338+
}
339+
340+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
341+
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
342+
343+
char base[256];
344+
char name[256];
345+
346+
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
347+
snprintf(name, 256, "%s", base);
348+
349+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
350+
if (res) {
351+
return res;
352+
}
353+
354+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
355+
356+
return res;
357+
}
358+
321359
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
322360
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
323361

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me
113113
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
114114
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
115115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
116+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
117+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
116118
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
117119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
118120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
870870
case GGML_OP_SUM:
871871
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
872872
case GGML_OP_SUM_ROWS:
873+
case GGML_OP_CUMSUM:
873874
case GGML_OP_MEAN:
874875
case GGML_OP_SOFT_MAX:
875876
case GGML_OP_GROUP_NORM:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,45 @@ typedef struct {
612612
uint64_t nb3;
613613
} ggml_metal_kargs_sum_rows;
614614

615+
typedef struct {
616+
int64_t ne00;
617+
int64_t ne01;
618+
int64_t ne02;
619+
int64_t ne03;
620+
uint64_t nb00;
621+
uint64_t nb01;
622+
uint64_t nb02;
623+
uint64_t nb03;
624+
int64_t net0;
625+
int64_t net1;
626+
int64_t net2;
627+
int64_t net3;
628+
uint64_t nbt0;
629+
uint64_t nbt1;
630+
uint64_t nbt2;
631+
uint64_t nbt3;
632+
bool outb;
633+
} ggml_metal_kargs_cumsum_blk;
634+
635+
typedef struct {
636+
int64_t ne00;
637+
int64_t ne01;
638+
int64_t ne02;
639+
int64_t ne03;
640+
uint64_t nb00;
641+
uint64_t nb01;
642+
uint64_t nb02;
643+
uint64_t nb03;
644+
int64_t net0;
645+
int64_t net1;
646+
int64_t net2;
647+
int64_t net3;
648+
uint64_t nbt0;
649+
uint64_t nbt1;
650+
uint64_t nbt2;
651+
uint64_t nbt3;
652+
} ggml_metal_kargs_cumsum_add;
653+
615654
typedef struct {
616655
int32_t ne00;
617656
int32_t ne01;

0 commit comments

Comments
 (0)