diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 0b10e5f6a..c747c1c80 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -461,6 +461,56 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+template <typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int  lane_id = threadIdx.x % width;
+    const auto mask    = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(mask, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+    const int  lane_id = threadIdx.x % width;
+    const auto mask    = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(mask, a.x, offset, width);
+        const float t_y = __shfl_up_sync(mask, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int  lane_id = threadIdx.x % width;
+    const auto mask    = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const half2 t = __shfl_up_sync(mask, a, offset, width);
+        if (lane_id >= offset) {
+            a = __hadd2(a, t);
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
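For context, the shuffle loop above is a Kogge–Stone style inclusive scan: at step `offset` every lane adds the value held by the lane `offset` positions below it, so after log2(width) steps lane i holds the sum of lanes 0..i. The standalone sketch below (toy kernel and host code invented for illustration, not part of ggml) shows the same pattern hard-coded to a single 32-lane warp; every lane starts at 1, so the expected output is 1 2 3 ... 32.

```cuda
#include <cstdio>

// Toy kernel: each lane starts with 1, so after an inclusive prefix sum
// lane i should hold i + 1. Same shuffle pattern as warp_prefix_inclusive_sum.
__global__ void scan_demo(int * out) {
    const int lane = threadIdx.x;       // 0..31, exactly one warp per block
    int x = 1;
    const unsigned mask = 0xffffffffu;  // all 32 lanes participate
    for (int offset = 1; offset < 32; offset <<= 1) {
        const int t = __shfl_up_sync(mask, x, offset, 32);
        if (lane >= offset) {
            x += t;
        }
    }
    out[lane] = x;
}

int main() {
    int * d_out = nullptr;
    cudaMalloc(&d_out, 32*sizeof(int));
    scan_demo<<<1, 32>>>(d_out);
    int h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 32; ++i) {
        printf("%d ", h_out[i]);        // expected: 1 2 3 ... 32
    }
    printf("\n");
    cudaFree(d_out);
    return 0;
}
```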
diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
new file mode 100644
index 000000000..030397d40
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,135 @@
+#include "cumsum.cuh"
+
+// Kernel to compute the cumulative sum along the innermost dimension (ne[0]).
+// Each block processes one row (ne00 elements).
+// The algorithm matches the Metal implementation:
+//   1. Each warp computes an inclusive prefix sum within itself.
+//   2. The last thread of each warp stores the warp total in shared memory.
+//   3. All warps synchronize.
+//   4. Each element adds the sum of all preceding warps.
+template <typename T>
+static __global__ void cumsum_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {
+
+    // Shared memory to store the warp sums (always accumulate in float).
+    extern __shared__ float shmem[];
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T       * dst_row = (T       *) ((      char *) dst + i1*nb1  + i2*nb2  + i3*nb3);
+
+    const int tid     = threadIdx.x;
+    const int lane_id = tid % WARP_SIZE;
+
+    if (tid >= ne00) {
+        return;
+    }
+
+    // Phase 1: each thread processes elements at stride blockDim.x and
+    // computes warp-level inclusive prefix sums.
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        // Load the value and compute the prefix sum within the warp.
+        float val = static_cast<float>(src_row[i0]);
+        val = warp_prefix_inclusive_sum(val);
+        dst_row[i0] = static_cast<T>(val);
+
+        // The last thread of the warp stores its sum to shared memory at a position based on the data index.
+        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
+            const int shmem_idx = i0 / WARP_SIZE;
+            shmem[shmem_idx] = val;
+        }
+    }
+
+    // Sync once after all warp prefix sums have been computed.
+    __syncthreads();
+
+    // Phase 2: add the sum of all preceding warp groups to each element.
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        const int shmem_idx = i0 / WARP_SIZE;
+        float sum = 0.0f;
+        for (int j = 0; j < shmem_idx; ++j) {
+            sum += shmem[j];
+        }
+        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
+    }
+}
+
+template <typename T>
+static void cumsum_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
+        cudaStream_t stream) {
+
+    dim3 grid_dims(ne01, ne02, ne03);
+
+    // Shared memory size: one float per warp.
+    const int    num_warps  = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
+    const size_t shmem_size = num_warps * sizeof(float);
+
+    int block_size = num_warps * WARP_SIZE;
+    if (block_size > CUDA_CUMSUM_BLOCK_SIZE) {
+        block_size = CUDA_CUMSUM_BLOCK_SIZE;
+    }
+    dim3 block_dims(block_size, 1, 1);
+
+    cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+        src, dst,
+        ne00, ne01, ne02, ne03,
+        nb00, nb01, nb02, nb03,
+        nb0,  nb1,  nb2,  nb3
+    );
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                cumsum_cuda(
+                    (const float *) src0->data, (float *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                cumsum_cuda(
+                    (const half *) src0->data, (half *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                cumsum_cuda(
+                    (const nv_bfloat16 *) src0->data, (nv_bfloat16 *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
+                    stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
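Read as a whole, the kernel is a chunked scan: the row is split into WARP_SIZE-wide chunks, each chunk is scanned independently in phase 1, and phase 2 adds to every element the totals of all preceding chunks (recomputed per element from shared memory). The plain C++ sketch below mirrors that arithmetic for a single float row; `cumsum_row_ref` is a hypothetical helper, not part of this patch, but it can serve as a host-side reference when checking the CUDA output.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Reference cumulative sum over one row, structured like the kernel:
// per-chunk inclusive scans, then adding the preceding chunk totals.
static std::vector<float> cumsum_row_ref(const std::vector<float> & src, size_t chunk = 32) {
    const size_t n = src.size();
    std::vector<float> dst(n);
    std::vector<float> chunk_sum((n + chunk - 1) / chunk, 0.0f);

    // Phase 1: inclusive scan inside each chunk, remember each chunk's total.
    for (size_t c = 0; c*chunk < n; ++c) {
        float acc = 0.0f;
        for (size_t i = c*chunk; i < std::min(n, (c + 1)*chunk); ++i) {
            acc += src[i];
            dst[i] = acc;
        }
        chunk_sum[c] = acc;
    }

    // Phase 2: every element adds the sum of all preceding chunk totals
    // (the kernel recomputes this small sum per element from shared memory).
    for (size_t c = 0; c*chunk < n; ++c) {
        float sum = 0.0f;
        for (size_t j = 0; j < c; ++j) {
            sum += chunk_sum[j];
        }
        for (size_t i = c*chunk; i < std::min(n, (c + 1)*chunk); ++i) {
            dst[i] += sum;
        }
    }
    return dst;
}
```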
diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh
new file mode 100644
index 000000000..782d1d92e
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a844a3d99..689e5dfc3 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -54,6 +54,8 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
new file mode 100644
index 000000000..b531f6963
--- /dev/null
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -0,0 +1,104 @@
+#include "tri.cuh"
+#include "ggml.h"
+
+// Triangle type comparison - determines which elements to keep.
+__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) {
+    switch (type) {
+        case GGML_TRI_TYPE_LOWER:      return i <  r;
+        case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
+        case GGML_TRI_TYPE_UPPER:      return i >  r;
+        case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
+        default:                       return false;
+    }
+}
+
+template <typename T>
+static __global__ void tri_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
+        const ggml_tri_type ttype) {
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T       * dst_row = (T       *) ((      char *) dst + i1*nb1  + i2*nb2  + i3*nb3);
+
+    // Each thread processes elements at stride blockDim.x.
+    for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
+        dst_row[i0] = tri_compare(i0, i1, ttype) ? src_row[i0] : static_cast<T>(0.f);
+    }
+}
+
+template <typename T>
+static void tri_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
+        const ggml_tri_type ttype,
+        cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+
+    tri_kernel<<<grid_dims, block_dims, 0, stream>>>(
+        src, dst,
+        ne00, ne01, ne02, ne03,
+        nb00, nb01, nb02, nb03,
+        nb0,  nb1,  nb2,  nb3,
+        ttype
+    );
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                tri_cuda(
+                    (const float *) src0->data, (float *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                tri_cuda(
+                    (const half *) src0->data, (half *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                tri_cuda(
+                    (const nv_bfloat16 *) src0->data, (nv_bfloat16 *) dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
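For the triangle masks, i is the column index (i0 in the kernel) and r is the row index (i1): LOWER keeps strictly-below-diagonal elements, LOWER_DIAG also keeps the diagonal, and the UPPER variants are the mirror images; everything else is written as zero. The host-side snippet below mirrors tri_compare with a local stand-in enum (the real ggml_tri_type is defined in ggml.h) purely to visualize the four masks on a 4x4 matrix.

```cpp
#include <cstdio>

// Local stand-in for ggml_tri_type, for illustration only.
enum tri_type { LOWER, LOWER_DIAG, UPPER, UPPER_DIAG };

// Host-side mirror of the device predicate tri_compare.
static bool keep(int i, int r, tri_type type) {
    switch (type) {
        case LOWER:      return i <  r;  // strictly below the diagonal
        case LOWER_DIAG: return i <= r;  // below or on the diagonal
        case UPPER:      return i >  r;  // strictly above the diagonal
        case UPPER_DIAG: return i >= r;  // above or on the diagonal
    }
    return false;
}

int main() {
    const tri_type types[] = { LOWER, LOWER_DIAG, UPPER, UPPER_DIAG };
    const char *   names[] = { "LOWER", "LOWER_DIAG", "UPPER", "UPPER_DIAG" };
    for (int t = 0; t < 4; ++t) {
        printf("%s:\n", names[t]);
        for (int r = 0; r < 4; ++r) {      // r: row index (i1 in the kernel)
            for (int i = 0; i < 4; ++i) {  // i: column index (i0 in the kernel)
                printf("%c ", keep(i, r, types[t]) ? 'x' : '.');
            }
            printf("\n");
        }
    }
    return 0;
}
```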
diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh
new file mode 100644
index 000000000..a4cc66750
--- /dev/null
+++ b/ggml/src/ggml-cuda/tri.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 60bab47b9..306fa15b9 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7938,6 +7938,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
 
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER,      GGML_TYPE_F32, { 256,  256,  4, 4 }));
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {