@@ -1853,32 +1853,54 @@ kernel void kernel_cumsum(
18531853 ushort sgitg[[simdgroup_index_in_threadgroup]],
18541854 ushort tiisg[[thread_index_in_simdgroup]],
18551855 ushort3 ntg[[threads_per_threadgroup]]) {
1856- const int64_t i3 = tgpig.z ;
1857- const int64_t i2 = tgpig.y ;
1858- const int64_t i1 = tgpig.x ;
18591856
1860- if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01 ) {
1857+ // Figure out the size and stride of the cumsum dim
1858+ const int64_t ne_dim = (args.dim == 0 ) ? args.ne00 : (args.dim == 1 ) ? args.ne01 : (args.dim == 2 ) ? args.ne02 : args.ne03 ;
1859+ const int64_t nb_dim_src = (args.dim == 0 ) ? args.nb00 : (args.dim == 1 ) ? args.nb01 : (args.dim == 2 ) ? args.nb02 : args.nb03 ;
1860+ const int64_t nb_dim_dst = (args.dim == 0 ) ? args.nb0 : (args.dim == 1 ) ? args.nb1 : (args.dim == 2 ) ? args.nb2 : args.nb3 ;
1861+
1862+ // Map threadgroup indices to actual tensor dimensions
1863+ // tgpig.x, tgpig.y, tgpig.z represent the 3 non-cumsum dimensions
1864+ // tpitg.x represents position in the cumsum dimension
1865+ int64_t grid_indices[3 ] = {int64_t (tgpig.x ), int64_t (tgpig.y ), int64_t (tgpig.z )};
1866+ int64_t i_vals[4 ];
1867+
1868+ int grid_idx = 0 ;
1869+ for (int d = 0 ; d < 4 ; ++d) {
1870+ if (d == args.dim ) {
1871+ i_vals[d] = 0 ; // Will be set in the loop below
1872+ } else {
1873+ i_vals[d] = grid_indices[grid_idx++];
1874+ }
1875+ }
1876+
1877+ // Base index offsets. The cumsum dim will be further offset by the position
1878+ // in the threadgroup
1879+ const int64_t i0 = i_vals[0 ];
1880+ const int64_t i1 = i_vals[1 ];
1881+ const int64_t i2 = i_vals[2 ];
1882+ const int64_t i3 = i_vals[3 ];
1883+
1884+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01 || i0 >= args.ne00 ) {
18611885 return ;
18621886 }
18631887
1864- device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03 );
1865- device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 );
1888+ // Each thread processes elements at stride ntg.x along the cumsum dimension
1889+ for (int64_t i_dim = tpitg.x ; i_dim < ne_dim; i_dim += ntg.x ) {
1890+ const int64_t offset_src = i0*args.nb00 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03 + i_dim*nb_dim_src;
1891+ const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst;
18661892
1867- // Each thread is a single element of the row if ne00 < max threads per
1868- // threadgroup, so this will loop once for each index that this thread is
1869- // responsible for
1870- for (int64_t i0 = tpitg.x ; i0 < args.ne00 ; i0 += ntg.x ) {
1893+ device const T * src_ptr = (device const T *) ((device const char *) src0 + offset_src);
1894+ device T * dst_ptr = (device T *) ((device char *) dst + offset_dst);
18711895
1872- // Each thread does simd_prefix_inclusive_sum => every element of row
1873- // now holds cumsum of the simd group
1874- float sumf = static_cast <float >(src_row[i0]);
1896+ // Each thread does simd_prefix_inclusive_sum
1897+ float sumf = static_cast <float >(src_ptr[0 ]);
18751898 sumf = simd_prefix_inclusive_sum (sumf);
1876- dst_row[i0 ] = static_cast <T>(sumf);
1899+ dst_ptr[ 0 ] = static_cast <T>(sumf);
18771900
1878- // If this is the last element of the simd group, store its value in
1879- // shared memory
1880- if (tiisg == N_SIMDWIDTH - 1 || i0 == args.ne00 - 1 ) {
1881- const ushort shmem_idx = i0 / N_SIMDWIDTH;
1901+ // If this is the last element of the simd group, store its value in shared memory
1902+ if (tiisg == N_SIMDWIDTH - 1 || i_dim == ne_dim - 1 ) {
1903+ const ushort shmem_idx = i_dim / N_SIMDWIDTH;
18821904 shmem_f32[shmem_idx] = sumf;
18831905 }
18841906 }
@@ -1887,10 +1909,13 @@ kernel void kernel_cumsum(
18871909 threadgroup_barrier (mem_flags::mem_threadgroup);
18881910
18891911 // Each element then adds the final value of all preceding simd groups
1890- for (int64_t i0 = tpitg.x ; i0 < args.ne00 ; i0 += ntg.x ) {
1891- const ushort shmem_idx = i0 / N_SIMDWIDTH;
1912+ for (int64_t i_dim = tpitg.x ; i_dim < ne_dim; i_dim += ntg.x ) {
1913+ const int64_t offset_dst = i0*args.nb0 + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 + i_dim*nb_dim_dst;
1914+ device T * dst_ptr = (device T *) ((device char *) dst + offset_dst);
1915+
1916+ const ushort shmem_idx = i_dim / N_SIMDWIDTH;
18921917 for (ushort j = 0 ; j < shmem_idx; ++j) {
1893- dst_row[i0 ] += static_cast <T>(shmem_f32[j]);
1918+ dst_ptr[ 0 ] += static_cast <T>(shmem_f32[j]);
18941919 }
18951920 }
18961921}
0 commit comments