|
| 1 | +#include "cumsum.cuh" |
| 2 | + |
| 3 | +// Kernel to compute cumulative sum along the innermost dimension (ne[0]) |
| 4 | +// Each block processes one row (ne[0] elements) |
| 5 | +// Algorithm matches Metal implementation: |
| 6 | +// 1. Each warp computes prefix sum within itself |
| 7 | +// 2. Last thread of each warp stores result in shared memory |
| 8 | +// 3. All warps sync |
| 9 | +// 4. Each element adds the sum of all preceding warps |
| 10 | + |
| 11 | +template<typename T> |
| 12 | +static __global__ void cumsum_kernel( |
| 13 | + const T * src, T * dst, |
| 14 | + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, |
| 15 | + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, |
| 16 | + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { |
| 17 | + |
| 18 | + // Shared memory to store warp sums (always use float for accumulation) |
| 19 | + extern __shared__ float shmem[]; |
| 20 | + |
| 21 | + const int64_t i3 = blockIdx.z; |
| 22 | + const int64_t i2 = blockIdx.y; |
| 23 | + const int64_t i1 = blockIdx.x; |
| 24 | + |
| 25 | + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { |
| 26 | + return; |
| 27 | + } |
| 28 | + |
| 29 | + const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); |
| 30 | + T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); |
| 31 | + |
| 32 | + const int tid = threadIdx.x; |
| 33 | + const int lane_id = tid % WARP_SIZE; |
| 34 | + |
| 35 | + // Phase 1: Each thread processes elements at stride blockDim.x |
| 36 | + // Compute warp-level prefix sums |
| 37 | + for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { |
| 38 | + // Load value and compute prefix sum within warp |
| 39 | + float val = static_cast<float>(src_row[i0]); |
| 40 | + val = warp_prefix_inclusive_sum(val); |
| 41 | + dst_row[i0] = static_cast<T>(val); |
| 42 | + |
| 43 | + // Last thread of warp stores its sum to shared memory at position based on data index |
| 44 | + if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) { |
| 45 | + const int shmem_idx = i0 / WARP_SIZE; |
| 46 | + shmem[shmem_idx] = val; |
| 47 | + } |
| 48 | + } |
| 49 | + |
| 50 | + // Sync once after all warp prefix sums are computed |
| 51 | + __syncthreads(); |
| 52 | + |
| 53 | + // Phase 2: Add the sum of all preceding warp groups to each element |
| 54 | + for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { |
| 55 | + const int shmem_idx = i0 / WARP_SIZE; |
| 56 | + float sum = 0.0f; |
| 57 | + for (int j = 0; j < shmem_idx; ++j) { |
| 58 | + sum += shmem[j]; |
| 59 | + } |
| 60 | + dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum); |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +template<typename T> |
| 65 | +static void cumsum_cuda( |
| 66 | + const T * src, T * dst, |
| 67 | + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, |
| 68 | + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, |
| 69 | + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, |
| 70 | + cudaStream_t stream) { |
| 71 | + |
| 72 | + dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); |
| 73 | + dim3 grid_dims(ne01, ne02, ne03); |
| 74 | + |
| 75 | + // Shared memory size: one float per warp |
| 76 | + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; |
| 77 | + const size_t shmem_size = num_warps * sizeof(float); |
| 78 | + |
| 79 | + cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>( |
| 80 | + src, dst, |
| 81 | + ne00, ne01, ne02, ne03, |
| 82 | + nb00, nb01, nb02, nb03, |
| 83 | + nb0, nb1, nb2, nb3 |
| 84 | + ); |
| 85 | +} |
| 86 | + |
| 87 | +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 88 | + const ggml_tensor * src0 = dst->src[0]; |
| 89 | + cudaStream_t stream = ctx.stream(); |
| 90 | + |
| 91 | + GGML_ASSERT(src0->type == dst->type); |
| 92 | + switch(src0->type) { |
| 93 | + case GGML_TYPE_F32: |
| 94 | + { |
| 95 | + cumsum_cuda( |
| 96 | + (const float *)src0->data, (float *)dst->data, |
| 97 | + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], |
| 98 | + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], |
| 99 | + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], |
| 100 | + stream |
| 101 | + ); |
| 102 | + } break; |
| 103 | + case GGML_TYPE_F16: |
| 104 | + { |
| 105 | + cumsum_cuda( |
| 106 | + (const half *)src0->data, (half *)dst->data, |
| 107 | + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], |
| 108 | + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], |
| 109 | + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], |
| 110 | + stream |
| 111 | + ); |
| 112 | + } break; |
| 113 | + case GGML_TYPE_BF16: |
| 114 | + { |
| 115 | + cumsum_cuda( |
| 116 | + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, |
| 117 | + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], |
| 118 | + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], |
| 119 | + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], |
| 120 | + stream |
| 121 | + ); |
| 122 | + } break; |
| 123 | + default: |
| 124 | + GGML_ABORT("fatal error"); |
| 125 | + } |
| 126 | +} |
0 commit comments