
Commit ee13af1

feat(ggml-metal): Support arbitrary dim and non-cont in cumsum

Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: d1e15c0

File tree: 6 files changed (+68, -28 lines)


ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 3 deletions

```diff
@@ -320,8 +320,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
 }
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
-
     char base[256];
     char name[256];
 
@@ -338,7 +336,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_
     }
 
     // one shared memory element for each simd group in the threadgroup
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
     const int nsg = (ne00 + 31)/32;
     ggml_metal_pipeline_set_smem(res, nsg*sizeof(float));
 
```
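The pipeline still sizes threadgroup memory as one float per 32-wide simd group over ne00. A standalone sketch of that sizing arithmetic (plain C++ with example values, not ggml API):

```cpp
#include <cstdio>

int main() {
    // One shared-memory float per 32-wide simd group covering the row.
    const int ne00 = 100;               // example row length (assumed value)
    const int nsg  = (ne00 + 31)/32;    // -> 4 simd groups
    printf("smem = %zu bytes\n", nsg*sizeof(float)); // -> smem = 16 bytes
    return 0;
}
```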

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions

```diff
@@ -665,6 +665,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_TRI:
             return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_CUMSUM:
+            return has_simdgroup_reduction;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
```

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions

```diff
@@ -585,6 +585,7 @@ typedef struct {
     uint64_t nb1;
     uint64_t nb2;
     uint64_t nb3;
+    int32_t  dim;
 } ggml_metal_kargs_cumsum;
 
 typedef struct {
```
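For context, the full argument struct plausibly looks like the sketch below; only nb1–nb3 and the new dim field are visible in this hunk, and the remaining fields are inferred from the initializer in ggml_metal_op_cumsum further down, so treat them as assumptions:

```cpp
// Sketch of ggml_metal_kargs_cumsum after this change (fields other than
// nb1..nb3 and dim are inferred from usage, not shown in this hunk).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03; // src extents (inferred)
    uint64_t nb00, nb01, nb02, nb03; // src strides in bytes (inferred)
    uint64_t nb0,  nb1,  nb2,  nb3;  // dst strides in bytes (nb1..nb3 shown above)
    int32_t  dim;                    // new: index of the dimension to accumulate over
} ggml_metal_kargs_cumsum;
```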

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 18 additions & 3 deletions

```diff
@@ -971,6 +971,8 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
     GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
 
+    const int32_t dim = (int32_t) op->op_params[0];
+
     ggml_metal_kargs_cumsum args = {
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
@@ -988,18 +990,31 @@
         /*.nb1  =*/ nb1,
         /*.nb2  =*/ nb2,
         /*.nb3  =*/ nb3,
+        /*.dim  =*/ dim
     };
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cumsum(lib, op);
 
+    // Dimension being accumulated
+    const int64_t ne_dim = op->src[0]->ne[dim];
+
+    // Grid dimensions: the GGML_MAX_DIMS-1 non-cumsum dimensions
+    int64_t grid_dims[GGML_MAX_DIMS - 1];
+    int grid_idx = 0;
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d != dim) {
+            grid_dims[grid_idx++] = op->src[0]->ne[d];
+        }
+    }
+
     int nth = 32; // SIMD width
 
-    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    while (nth < ne_dim && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00);
+    nth = std::min(nth, (int) ne_dim);
 
     const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
 
@@ -1010,7 +1025,7 @@
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, grid_dims[0], grid_dims[1], grid_dims[2], nth, 1, 1);
 
     return 1;
 }
```
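A standalone sketch of the dispatch-shape logic above (plain C++, not ggml API; the extents and the max-threadgroup cap are example values):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne[4] = {10, 5, 4, 3}; // example src extents (assumed)
    const int     dim   = 2;             // accumulate along dimension 2

    // The three non-cumsum extents become the threadgroup grid.
    int64_t grid_dims[3];
    int grid_idx = 0;
    for (int d = 0; d < 4; ++d) {
        if (d != dim) {
            grid_dims[grid_idx++] = ne[d];
        }
    }

    // nth: power of two covering ne[dim], capped at the pipeline's max
    // threads per threadgroup (1024 is a stand-in for the real query).
    const int max_nth = 1024;
    int nth = 32; // SIMD width
    while (nth < ne[dim] && nth < max_nth) {
        nth *= 2;
    }
    nth = std::min(nth, (int) ne[dim]);

    printf("grid = %lld x %lld x %lld, nth = %d\n",
           (long long) grid_dims[0], (long long) grid_dims[1],
           (long long) grid_dims[2], nth); // -> grid = 10 x 5 x 3, nth = 4
    return 0;
}
```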

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 46 additions & 21 deletions

```diff
@@ -1853,32 +1853,54 @@ kernel void kernel_cumsum(
         ushort  sgitg[[simdgroup_index_in_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort3 ntg[[threads_per_threadgroup]]) {
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
 
-    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+    // Figure out the size and stride of the cumsum dim
+    const int64_t ne_dim     = (args.dim == 0) ? args.ne00 : (args.dim == 1) ? args.ne01 : (args.dim == 2) ? args.ne02 : args.ne03;
+    const int64_t nb_dim_src = (args.dim == 0) ? args.nb00 : (args.dim == 1) ? args.nb01 : (args.dim == 2) ? args.nb02 : args.nb03;
+    const int64_t nb_dim_dst = (args.dim == 0) ? args.nb0  : (args.dim == 1) ? args.nb1  : (args.dim == 2) ? args.nb2  : args.nb3;
+
+    // Map threadgroup indices to actual tensor dimensions
+    // tgpig.x, tgpig.y, tgpig.z represent the 3 non-cumsum dimensions
+    // tpitg.x represents position in the cumsum dimension
+    int64_t grid_indices[3] = {int64_t(tgpig.x), int64_t(tgpig.y), int64_t(tgpig.z)};
+    int64_t i_vals[4];
+
+    int grid_idx = 0;
+    for (int d = 0; d < 4; ++d) {
+        if (d == args.dim) {
+            i_vals[d] = 0; // Will be set in the loop below
+        } else {
+            i_vals[d] = grid_indices[grid_idx++];
+        }
+    }
+
+    // Base index offsets. The cumsum dim will be further offset by the position
+    // in the threadgroup
+    const int64_t i0 = i_vals[0];
+    const int64_t i1 = i_vals[1];
+    const int64_t i2 = i_vals[2];
+    const int64_t i3 = i_vals[3];
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01 || i0 >= args.ne00) {
         return;
     }
 
-    device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
-    device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
+    // Each thread processes elements at stride ntg.x along the cumsum dimension
+    for (int64_t i_dim = tpitg.x; i_dim < ne_dim; i_dim += ntg.x) {
+        const int64_t offset_src = i0*args.nb00 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03 + i_dim*nb_dim_src;
+        const int64_t offset_dst = i0*args.nb0  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3  + i_dim*nb_dim_dst;
 
-    // Each thread is a single element of the row if ne00 < max threads per
-    // threadgroup, so this will loop once for each index that this thread is
-    // responsible for
-    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        device const T * src_ptr = (device const T *) ((device const char *) src0 + offset_src);
+        device T * dst_ptr = (device T *) ((device char *) dst + offset_dst);
 
-        // Each thread does simd_prefix_inclusive_sum => every element of row
-        // now holds cumsum of the simd group
-        float sumf = static_cast<float>(src_row[i0]);
+        // Each thread does simd_prefix_inclusive_sum
+        float sumf = static_cast<float>(src_ptr[0]);
         sumf = simd_prefix_inclusive_sum(sumf);
-        dst_row[i0] = static_cast<T>(sumf);
+        dst_ptr[0] = static_cast<T>(sumf);
 
-        // If this is the last element of the simd group, store its value in
-        // shared memory
-        if (tiisg == N_SIMDWIDTH - 1 || i0 == args.ne00 - 1) {
-            const ushort shmem_idx = i0 / N_SIMDWIDTH;
+        // If this is the last element of the simd group, store its value in shared memory
+        if (tiisg == N_SIMDWIDTH - 1 || i_dim == ne_dim - 1) {
+            const ushort shmem_idx = i_dim / N_SIMDWIDTH;
             shmem_f32[shmem_idx] = sumf;
         }
     }
@@ -1887,10 +1909,13 @@ kernel void kernel_cumsum(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Each element then adds the final value of all preceding simd groups
-    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
-        const ushort shmem_idx = i0 / N_SIMDWIDTH;
+    for (int64_t i_dim = tpitg.x; i_dim < ne_dim; i_dim += ntg.x) {
+        const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst;
+        device T * dst_ptr = (device T *) ((device char *) dst + offset_dst);
+
+        const ushort shmem_idx = i_dim / N_SIMDWIDTH;
         for (ushort j = 0; j < shmem_idx; ++j) {
-            dst_row[i0] += static_cast<T>(shmem_f32[j]);
+            dst_ptr[0] += static_cast<T>(shmem_f32[j]);
         }
     }
 }
```
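For reference, this is what the kernel computes, written as a minimal CPU model: an inclusive prefix sum along an arbitrary dim of a strided 4-D tensor (plain C++ sketch, independent of ggml/Metal; strides here are in elements rather than bytes):

```cpp
#include <cstdint>

// Inclusive cumulative sum along `dim` of a strided 4-D tensor.
// nbs/nbd are src/dst strides per dimension, in elements.
void cumsum_ref(const float * src, float * dst,
                const int64_t ne[4], const int64_t nbs[4], const int64_t nbd[4],
                int dim) {
    for (int64_t i3 = 0; i3 < ne[3]; ++i3)
    for (int64_t i2 = 0; i2 < ne[2]; ++i2)
    for (int64_t i1 = 0; i1 < ne[1]; ++i1)
    for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
        int64_t i[4] = {i0, i1, i2, i3};
        if (i[dim] != 0) {
            continue; // visit each line along `dim` exactly once
        }
        float acc = 0.0f;
        for (int64_t k = 0; k < ne[dim]; ++k) {
            i[dim] = k;
            acc += src[i[0]*nbs[0] + i[1]*nbs[1] + i[2]*nbs[2] + i[3]*nbs[3]];
            dst[i[0]*nbd[0] + i[1]*nbd[1] + i[2]*nbd[2] + i[3]*nbd[3]] = acc;
        }
    }
}
```

The kernel reaches the same result in two passes per line: each simd group first computes simd_prefix_inclusive_sum over its 32-element chunk and parks its last value in threadgroup memory, then every element adds the totals of all preceding simd groups.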

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -4865,7 +4865,7 @@ struct test_cumsum : public test_case {
     const std::array<int64_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        return VARS_TO_STR4(type, ne, dim, permute);
     }
 
     test_cumsum(ggml_type type = GGML_TYPE_F32,
```
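vars() now distinguishes cases by dim and permute as well. A hedged sketch of the kind of graph such a case exercises — ggml_new_tensor_4d and ggml_permute are real ggml API, while the exact ggml_cumsum builder signature on this branch (the accumulated dim lands in op_params[0]) is not shown in the hunk and is an assumption here:

```cpp
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // Non-contiguous input: a permuted view (dims 0 and 1 swapped).
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 10, 5, 4, 3);
    a = ggml_permute(ctx, a, 1, 0, 2, 3);

    // Assumed builder: on this branch the target dim is presumably set
    // via op_params[0] by whichever cumsum variant the test invokes.
    struct ggml_tensor * out = ggml_cumsum(ctx, a);

    (void) out;
    ggml_free(ctx);
    return 0;
}
```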
