diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index f31989c8c55c6..1bf8e705e359c 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -184,8 +184,9 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { const char * name = gguf_get_tensor_name (ctx, i); const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); + const char * type = ggml_type_name(gguf_get_tensor_type(ctx, i)); - printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s\n", __func__, i, name, size, offset, type); } } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 60c6b63d05978..beb7ee988097a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -474,6 +474,7 @@ extern "C" { GGML_OP_COS, GGML_OP_SUM, GGML_OP_SUM_ROWS, + GGML_OP_CUMSUM, GGML_OP_MEAN, GGML_OP_ARGMAX, GGML_OP_COUNT_EQUAL, @@ -529,6 +530,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_TRI, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -615,6 +617,13 @@ extern "C" { GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; + enum ggml_tri_type { + GGML_TRI_TYPE_UPPER_DIAG = 0, + GGML_TRI_TYPE_UPPER = 1, + GGML_TRI_TYPE_LOWER_DIAG = 2, + GGML_TRI_TYPE_LOWER = 3 + }; + struct ggml_init_params { // memory pool size_t mem_size; // bytes @@ -978,6 +987,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a); + // mean along rows GGML_API struct ggml_tensor * ggml_mean( struct ggml_context * ctx, @@ -2141,6 +2154,17 @@ extern "C" { int shift2, int shift3); + // Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value + GGML_API struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype); + + GGML_API struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype); // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index ba2a36d999128..9bd36ea8cacef 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1736,6 +1736,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_sum_rows(params, tensor); } break; + case GGML_OP_CUMSUM: + { + ggml_compute_forward_cumsum(params, tensor); + } break; case GGML_OP_MEAN: { ggml_compute_forward_mean(params, tensor); @@ -1948,6 +1952,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; + case GGML_OP_TRI: + { + ggml_compute_forward_tri(params, tensor); + } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor); @@ -2158,6 +2166,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_ARGMAX: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 1c43865ff65fc..4aaf08aa2ea6a 100644 --- a/ggml/src/ggml-cpu/ops.cpp 
+++ b/ggml/src/ggml-cpu/ops.cpp @@ -9,6 +9,7 @@ #include #include +#include // ggml_compute_forward_dup @@ -1394,6 +1395,127 @@ void ggml_compute_forward_sum( } } +// ggml_compute_forward_cumsum + +static void ggml_compute_forward_cumsum_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f32(ne00, dst_row, src_row); + } + } + } +} + +static void ggml_compute_forward_cumsum_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(dst->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + ggml_fp16_t * src_row = (ggml_fp16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_f16(ne00, dst_row, src_row); + } + } + } +} + +static void ggml_compute_forward_cumsum_bf16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(ggml_bf16_t)); + GGML_ASSERT(dst->nb[0] == sizeof(ggml_bf16_t)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + ggml_bf16_t * src_row = (ggml_bf16_t *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + ggml_bf16_t * dst_row = (ggml_bf16_t *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_cumsum_bf16(ne00, dst_row, src_row); + } + } + } +} + +void ggml_compute_forward_cumsum( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cumsum_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_cumsum_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_cumsum_bf16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum_rows static void ggml_compute_forward_sum_rows_f32( @@ -2140,6 +2262,112 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_tri + +static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const bool 
keep_org_val = isnan(c); + + // TODO: Is ggml_is_contiguous_rows safe and sufficient? + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + float * dst_ptr = (float *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src = (float *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype); + } + +} + +static void ggml_compute_forward_tri_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const bool keep_org_val = isnan(c); + + // TODO: Is ggml_is_contiguous_rows safe and sufficient? + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_f16(ne0, i01, dst_ptr, src, keep_org_val, GGML_FP32_TO_FP16(c), ttype); + } + +} + +static void ggml_compute_forward_tri_bf16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + const ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0]; + const float c = ggml_get_op_params_f32(dst, 1); + const bool keep_org_val = isnan(c); + + // TODO: Is ggml_is_contiguous_rows safe and sufficient? 
+ GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + + GGML_TENSOR_UNARY_OP_LOCALS + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *)((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_bf16_t * src = (ggml_bf16_t *)((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 ); + ggml_vec_tri_bf16(ne0, i01, dst_ptr, src, keep_org_val, GGML_FP32_TO_BF16(c), ttype); + } + +} + +void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tri_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_tri_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_tri_bf16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_erf static void ggml_compute_forward_gelu_erf_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 9824a03b45833..d6f8dedcd9c55 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -85,6 +86,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_back( const struct ggml_compute_params * params, diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 65c7dfb6b9a49..94031a1b01008 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1404,6 +1404,8 @@ inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const } } +// sum + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; @@ -1440,6 +1442,80 @@ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16 *s = sum; } +// tri + +// Applies a triangular mask to the input vector 'src' and writes the result to 'dst'. 
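+// Elements selected by the mask keep their original value or are replaced by the constant 'c';
+// all other elements are set to 0. For example, with GGML_TRI_TYPE_LOWER_DIAG, keep_org_val = false
+// and c = 1.0f, a 3x3 input becomes: row 0 = [1 0 0], row 1 = [1 1 0], row 2 = [1 1 1].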
+// Parameters:
+//   n            - number of elements
+//   r            - current row index
+//   dst          - output array
+//   src          - input array
+//   keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c'
+//   c            - constant value to use when not keeping original value
+//   type         - type of triangular mask (lower, upper, etc.)
+inline static bool _ggml_vec_tri_cmp(const int i, const int r, const enum ggml_tri_type type) {
+    switch (type) {
+        case GGML_TRI_TYPE_LOWER:      return i <  r; break;
+        case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; break;
+        case GGML_TRI_TYPE_UPPER:      return i >  r; break;
+        case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; break;
+        default: GGML_ABORT("Invalid tri type");
+    }
+}
+
+inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) {
+    for (int i = 0; i < n; ++i) {
+        dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : 0.0f;
+    }
+}
+
+inline static void ggml_vec_tri_f16(const int n, const int r, ggml_fp16_t * dst, const ggml_fp16_t * src, bool keep_org_val, ggml_fp16_t c, enum ggml_tri_type type) {
+    for (int i = 0; i < n; ++i) {
+        dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : 0;
+    }
+}
+
+inline static void ggml_vec_tri_bf16(const int n, const int r, ggml_bf16_t * dst, const ggml_bf16_t * src, bool keep_org_val, ggml_bf16_t c, enum ggml_tri_type type) {
+    const ggml_bf16_t zero = ggml_fp32_to_bf16(0);
+    for (int i = 0; i < n; ++i) {
+        dst[i] = _ggml_vec_tri_cmp(i, r, type) ? (keep_org_val ? src[i] : c) : zero;
+    }
+}
+
+// cumsum
+
+inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        if (i == 0) {
+            y[i] = x[i];
+        } else {
+            y[i] = y[i - 1] + x[i];
+        }
+    }
+}
+
+inline static void ggml_vec_cumsum_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        if (i == 0) {
+            y[i] = x[i];
+        } else {
+            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i - 1]) + GGML_CPU_FP16_TO_FP32(x[i]));
+        }
+    }
+}
+
+inline static void ggml_vec_cumsum_bf16(const int n, ggml_bf16_t * y, const ggml_bf16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        if (i == 0) {
+            y[i] = x[i];
+        } else {
+            y[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(y[i - 1]) + GGML_BF16_TO_FP32(x[i]));
+        }
+    }
+}
+
+// max
+
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     float max = -INFINITY;
@@ -1452,6 +1528,8 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #endif
 }
 
+// norm inv
+
 inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
     ggml_vec_norm_f32(n, s, x);
     *s = 1.f/(*s);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 41ff89c4d6922..199d8f9debc07 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -406,6 +406,56 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+template <typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int  lane_id = threadIdx.x % width;
+    const auto mask    = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(mask, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+    const int  lane_id = threadIdx.x % width;
+    const auto mask    = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(mask, a.x, offset, width);
+        const float t_y = __shfl_up_sync(mask, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int  lane_id = threadIdx.x % width;
+    const auto mask    = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const half2 t = __shfl_up_sync(mask, a, offset, width);
+        if (lane_id >= offset) {
+            a = __hadd2(a, t);
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
     if (width == ggml_cuda_get_physical_warp_size()) {
diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
new file mode 100644
index 0000000000000..e14be0721c699
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,126 @@
+#include "cumsum.cuh"
+
+// Kernel to compute cumulative sum along the innermost dimension (ne[0])
+// Each block processes one row (ne[0] elements)
+// Algorithm matches Metal implementation:
+// 1. Each warp computes prefix sum within itself
+// 2. Last thread of each warp stores result in shared memory
+// 3. All warps sync
+// 4. Each element adds the sum of all preceding warps
+
+template <typename T>
+static __global__ void cumsum_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {
+
+    // Shared memory to store warp sums (always use float for accumulation)
+    extern __shared__ float shmem[];
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T       * dst_row = (T       *) ((      char *) dst + i1*nb1  + i2*nb2  + i3*nb3);
+
+    const int tid     = threadIdx.x;
+    const int lane_id = tid % WARP_SIZE;
+
+    // Phase 1: Each thread processes elements at stride blockDim.x
+    // Compute warp-level prefix sums
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        // Load value and compute prefix sum within warp
+        float val = static_cast<float>(src_row[i0]);
+        val = warp_prefix_inclusive_sum(val);
+        dst_row[i0] = static_cast<T>(val);
+
+        // Last thread of warp stores its sum to shared memory at position based on data index
+        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
+            const int shmem_idx = i0 / WARP_SIZE;
+            shmem[shmem_idx] = val;
+        }
+    }
+
+    // Sync once after all warp prefix sums are computed
+    __syncthreads();
+
+    // Phase 2: Add the sum of all preceding warp groups to each element
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        const int shmem_idx = i0 / WARP_SIZE;
+        float sum = 0.0f;
+        for (int j = 0; j < shmem_idx; ++j) {
+            sum += shmem[j];
+        }
+        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
+    }
+}
+
+template <typename T>
+static void cumsum_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
+        cudaStream_t stream) {
+
+    dim3
block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + // Shared memory size: one float per warp + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + const size_t shmem_size = num_warps * sizeof(float); + + cumsum_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3 + ); +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == dst->type); + switch(src0->type) { + case GGML_TYPE_F32: + { + cumsum_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_F16: + { + cumsum_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_BF16: + { + cumsum_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 0000000000000..782d1d92e9bb1 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75fd6db14c514..7e11108c9684f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" @@ -46,6 +47,7 @@ #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" +#include "ggml-cuda/tri.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" @@ -2512,6 +2514,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; case GGML_OP_RWKV_WKV6: ggml_cuda_op_rwkv_wkv6(ctx, dst); break; @@ -3650,6 +3658,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 0000000000000..d9c4aa025dbaf --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,109 @@ +#include "tri.cuh" +#include "ggml.h" +#include + +// Triangle type comparison - determines which elements to keep +__device__ static inline bool 
tri_compare(const int i, const int r, const ggml_tri_type type) { + switch (type) { + case GGML_TRI_TYPE_LOWER: return i < r; + case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; + case GGML_TRI_TYPE_UPPER: return i > r; + case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; + default: return false; + } +} + +template +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const float c, const ggml_tri_type ttype) { + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + const bool keep_org_val = isnan(c); + + // Each thread processes elements at stride blockDim.x + for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = tri_compare(i0, i1, ttype) + ? (keep_org_val ? src_row[i0] : static_cast(c)) + : static_cast(0.f); + } +} + +template +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const float c, const ggml_tri_type ttype, + cudaStream_t stream) { + + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3, + c, ttype + ); +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const ggml_tri_type ttype = static_cast(ggml_get_op_params_i32(dst, 0)); + const float c = ggml_get_op_params_f32(dst, 1); + + GGML_ASSERT(src0->type == dst->type); + + switch(src0->type) { + case GGML_TYPE_F32: + { + tri_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, stream + ); + } break; + case GGML_TYPE_F16: + { + tri_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, stream + ); + } break; + case GGML_TYPE_BF16: + { + tri_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + c, ttype, stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 0000000000000..a4cc66750d3b5 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 866cd2da58576..41519ba4e51a2 
100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -318,6 +318,52 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "cumsum"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + // reuse existing precompiled pipeline, but allow memory size setting + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (!res) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + // one shared memory element for each simd group in the threadgroup + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + const int nsg = (ne00 + 31)/32; + ggml_metal_pipeline_set_smem(res, nsg*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "tri"; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + return ggml_metal_library_compile_pipeline(lib, base, name, nullptr); +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 28ae2e1765146..a8db4fa4e873a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -111,6 +111,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c3c83abe4e63e..e93ad6534e2cc 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -661,6 +661,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == 
GGML_TYPE_F32; + case GGML_OP_TRI: + return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_CUMSUM: case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index fa2d82cefb40e..8cb22668857b4 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -568,6 +568,46 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_sum_rows; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cumsum; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float c; + uint32_t ttype; +} ggml_metal_kargs_tri; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 4f9f6bda00a79..03520b7c29f77 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -310,6 +310,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_sum_rows(ctx, idx); } break; + case GGML_OP_CUMSUM: + { + n_fuse = ggml_metal_op_cumsum(ctx, idx); + } break; + case GGML_OP_TRI: + { + n_fuse = ggml_metal_op_tri(ctx, idx); + } break; case GGML_OP_SOFT_MAX: { n_fuse = ggml_metal_op_soft_max(ctx, idx); @@ -952,6 +960,117 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_cumsum args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cumsum(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = 
ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0]; + const float c = *((float *) &(op->op_params[1])); + + ggml_metal_kargs_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.c =*/ c, + /*.ttype =*/ static_cast(ttype) + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f352738698beb..ea2ee5337326a 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -52,6 +52,8 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 496610b154b6d..e4c9579981dba 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1826,6 +1826,117 @@ typedef decltype(kernel_sum_rows) kernel_sum_rows_t; template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template +kernel void kernel_cumsum( + constant ggml_metal_kargs_cumsum & args, + device const char * src0, + device const char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + 
i2*args.nb2 + i3*args.nb3); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + + // Each thread does simd_prefix_inclusive_sum => every element of row + // now holds cumsum of the simd group + float sumf = static_cast(src_row[i0]); + sumf = simd_prefix_inclusive_sum(sumf); + dst_row[i0] = static_cast(sumf); + + // If this is the last element of the simd group, store its value in + // shared memory + if (tiisg == N_SIMDWIDTH - 1 || i0 == args.ne00 - 1) { + const ushort shmem_idx = i0 / N_SIMDWIDTH; + shmem_f32[shmem_idx] = sumf; + } + } + + // Ensure all simd groups sync here before proceeding + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each element then adds the final value of all preceding simd groups + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + const ushort shmem_idx = i0 / N_SIMDWIDTH; + for (ushort j = 0; j < shmem_idx; ++j) { + dst_row[i0] += static_cast(shmem_f32[j]); + } + } +} + +typedef decltype(kernel_cumsum) kernel_cumsum_t; + +template [[host_name("kernel_cumsum_f32")]] kernel kernel_cumsum_t kernel_cumsum; +template [[host_name("kernel_cumsum_f16")]] kernel kernel_cumsum_t kernel_cumsum; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cumsum_bf16")]] kernel kernel_cumsum_t kernel_cumsum; +#endif + +inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) { + switch (type) { + // ggml.h:620 + case /* GGML_TRI_TYPE_LOWER */ 3: return i < r; break; + case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r; break; + case /* GGML_TRI_TYPE_UPPER */ 1: return i > r; break; + case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r; break; + } +} + +template +kernel void kernel_tri( + constant ggml_metal_kargs_tri & args, + device const char * src0, + device const char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + // Each thread is a single element of the row if ne00 < max threads per + // threadgroup, so this will loop once for each index that this thread is + // responsible for + const bool keep_org_val = isnan(args.c); + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) + ? (keep_org_val ? 
src_row[i0] : static_cast(args.c)) + : static_cast(0.f); + } +} + +typedef decltype(kernel_tri) kernel_tri_t; + +template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri; +template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri; +#endif + template kernel void kernel_soft_max( constant ggml_metal_kargs_soft_max & args, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bce1375ba3c0..7e4bfb07154b3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "COS", "SUM", "SUM_ROWS", + "CUMSUM", "MEAN", "ARGMAX", "COUNT_EQUAL", @@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "TRI", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -1019,7 +1021,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1039,6 +1041,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cos(x)", "Σx", "Σx_k", + "cumsum(x)", "Σx/n", "argmax(x)", "count_equal(x)", @@ -1094,6 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", + "tri(x)", "flash_attn_ext(x)", "flash_attn_back(x)", @@ -1123,7 +1127,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2337,6 +2341,20 @@ struct ggml_tensor * ggml_sum_rows( return result; } +// ggml_cumsum + +struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne); + + result->op = GGML_OP_CUMSUM; + result->src[0] = a; + + return result; +} + // ggml_mean struct ggml_tensor * ggml_mean( @@ -4968,6 +4986,33 @@ struct ggml_tensor * ggml_timestep_embedding( return result; } +// ggml_tri + +struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + float constant, + enum ggml_tri_type tritype) { + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, tritype); + ggml_set_op_params_f32(result, 1, constant); + + result->op = GGML_OP_TRI; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_tri_keep( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type tritype) { + + return ggml_tri(ctx, a, nan(""), tritype); +} + // ggml_argsort struct ggml_tensor * ggml_argsort( diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f29a1e98c9103..6d1a6cc24595c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -251,6 +251,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= s_copy->ne[0] == mctx->get_n_rs(); + + res &= s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); + + 
return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -458,8 +476,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - inp_attn->set_input(ubatch); - inp_rs->set_input(ubatch); + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; } // @@ -1808,6 +1864,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } @@ -1876,10 +1935,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); - auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f67927..caba9779b5d48 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * s_copy; // I32 [n_rs] // views of s_copy, computed once per graph @@ -232,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i { ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; + + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class 
llm_graph_input_cross_embd : public llm_graph_input_i { @@ -364,22 +370,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( + const llama_cparams & cparams, std::unique_ptr inp_attn, - std::unique_ptr inp_rs, - const llama_memory_hybrid_context * mctx) : + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), inp_rs(std::move(inp_rs)), + cparams(cparams), mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + std::unique_ptr inp_attn; std::unique_ptr inp_rs; llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + const llama_cparams cparams; + const llama_memory_hybrid_context * mctx; }; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index dfb8439e01bdf..a1b45e4a3cce3 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), - ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3a2876c808ffc..f272a39e57a0d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -175,6 +175,33 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t)); } +static std::vector ggml_get_float_value(uint8_t * buf, ggml_type type, size_t i, size_t bs, + bool quantized, std::vector & vq) { + const auto * tt = ggml_get_type_traits(type); + std::vector tv; + if (type == GGML_TYPE_F16) { + tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); + } else if (type == GGML_TYPE_BF16) { + tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); + } else if (type == GGML_TYPE_F32) { + tv.push_back(*(float *) &buf[i]); + } else if (type == GGML_TYPE_I64) { + tv.push_back((float)*(int64_t *) &buf[i]); + } else if (type == GGML_TYPE_I32) { + tv.push_back((float)*(int32_t *) &buf[i]); + } else if (type == GGML_TYPE_I16) { + tv.push_back((float)*(int16_t *) &buf[i]); + } else if (type == GGML_TYPE_I8) { + tv.push_back((float)*(int8_t *) &buf[i]); + } else if (quantized) { + tt->to_float(&buf[i], vq.data(), bs); + tv.insert(tv.end(), vq.begin(), vq.end()); + } else { + GGML_ABORT("fatal error"); + } + return tv; +} + static std::vector tensor_to_float(const ggml_tensor * t) { std::vector tv; tv.reserve(ggml_nelements(t)); @@ -182,7 +209,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { std::vector buf(ggml_nbytes(t)); ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); - const auto * tt = ggml_get_type_traits(t->type); size_t bs = ggml_blck_size(t->type); std::vector vq(ggml_blck_size(t->type)); bool quantized = ggml_is_quantized(t->type); @@ -193,26 +219,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { for (int64_t i1 = 
0; i1 < t->ne[1]; i1++) { for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; - if (t->type == GGML_TYPE_F16) { - tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); - } else if (t->type == GGML_TYPE_BF16) { - tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); - } else if (t->type == GGML_TYPE_F32) { - tv.push_back(*(float *) &buf[i]); - } else if (t->type == GGML_TYPE_I64) { - tv.push_back((float)*(int64_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I32) { - tv.push_back((float)*(int32_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I16) { - tv.push_back((float)*(int16_t *) &buf[i]); - } else if (t->type == GGML_TYPE_I8) { - tv.push_back((float)*(int8_t *) &buf[i]); - } else if (quantized) { - tt->to_float(&buf[i], vq.data(), bs); - tv.insert(tv.end(), vq.begin(), vq.end()); - } else { - GGML_ABORT("fatal error"); - } + const auto fvs = ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq); + tv.insert(tv.end(), fvs.begin(), fvs.end()); } } } @@ -221,6 +229,107 @@ static std::vector tensor_to_float(const ggml_tensor * t) { return tv; } +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static void ggml_print_tensor(ggml_tensor * t, int64_t n = 3) { + GGML_ASSERT(t != nullptr); + GGML_ASSERT(n > 0); + + std::stringstream src_ss; + src_ss << "("; + size_t last_src = 0; + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (t->src[i] != nullptr) { + last_src = i; + } + } + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (t->src[i] != nullptr) { + src_ss << t->src[i]->name << "{" << ggml_ne_string(t->src[i]) <<"}"; + } + if (i <= last_src) { + src_ss << ", "; + } + } + src_ss << ")"; + + printf("%s: %24s = (%s) %10s%s = {%s}\n", __func__, + t->name, ggml_type_name(t->type), ggml_op_desc(t), + src_ss.str().c_str(), + ggml_ne_string(t).c_str()); + + std::vector tv; + tv.reserve(ggml_nelements(t)); + + std::vector buf(ggml_nbytes(t)); + ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); + + size_t bs = ggml_blck_size(t->type); + std::vector vq(ggml_blck_size(t->type)); + bool quantized = ggml_is_quantized(t->type); + + float sum = 0; + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + for (const auto & val : ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq)) { + sum += val; + } + } + } + } + } + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + printf(" [\n"); + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + if (i2 == n && t->ne[2] > 2*n) { + printf(" ..., \n"); + i2 = t->ne[2] - n; + } + printf(" [\n"); + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + if (i1 == n && t->ne[1] > 2*n) { + printf(" ..., \n"); + i1 = t->ne[1] - n; + } + printf(" ["); + for (int64_t i0 = 0; i0 < t->ne[0]; i0++) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + if (i0 == n && t->ne[0] > 2*n) { + printf("..., "); + i0 = t->ne[0] - n; + } + for (const auto & v : ggml_get_float_value(buf.data(), t->type, i, bs, quantized, vq)) { + printf("%12.4f", v); + } + if (i0 < t->ne[0] - 1) printf(", "); + } + printf("],\n"); + } + printf(" ],\n"); + } + printf(" ]\n"); + printf(" sum = %f\n", sum); + } + + // TODO: make 
this abort configurable/optional? + if (std::isnan(sum)) { + printf("encountered NaN - aborting\n"); + exit(0); + } +} + // normalized mean squared error = mse(a, b) / mse(a, 0) static double nmse(const float * a, const float * b, size_t n) { double mse_a_b = 0.0; @@ -980,6 +1089,8 @@ static std::unique_ptr create_printer(output_formats format) { GGML_ABORT("invalid output format"); } +// test case definition + struct test_case { virtual ~test_case() {} @@ -1056,6 +1167,9 @@ struct test_case { std::vector sentinels; + // set to 1 to print tensors, 2 to fully print tensors + int verbose = 0; + void add_sentinel(ggml_context * ctx) { if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) { return; @@ -1201,6 +1315,7 @@ struct test_case { // compare struct callback_userdata { bool ok; + int verbose; double max_err; ggml_backend_t backend1; ggml_backend_t backend2; @@ -1208,6 +1323,7 @@ struct test_case { callback_userdata ud { true, + verbose, max_nmse_err(), backend1, backend2 @@ -1232,6 +1348,11 @@ struct test_case { } } + if (ud->verbose) { + ggml_print_tensor(t1, ud->verbose >= 2 ? 1e10 : 3); + ggml_print_tensor(t2, ud->verbose >= 2 ? 1e10 : 3); + } + std::vector f1 = tensor_to_float(t1); std::vector f2 = tensor_to_float(t2); @@ -1261,11 +1382,12 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); - //exit(1); + if (ud->verbose) { + for (int i = 0; i < (int) f1.size(); i++) { + printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + } + printf("\n"); + } ud->ok = false; } return true; @@ -4661,6 +4783,69 @@ struct test_sum_rows : public test_case { } }; +// GGML_OP_CUMSUM +struct test_cumsum : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cumsum(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} + + + double max_nmse_err() override { + // Lower precision types have expected precision errors in lower bits + if (type == GGML_TYPE_BF16 || type == GGML_TYPE_F16) { + return 1e-5; + } + return 1e-7; + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_cumsum(ctx, a); + ggml_set_name(out, "out"); + + return out; + } +}; + +// GGML_OP_TRI +struct test_tri : public test_case { + const ggml_type type; + const std::array ne; + const ggml_tri_type tri_type; + const float c; + + std::string vars() override { + return VARS_TO_STR4(type, ne, tri_type, c); + } + + test_tri(ggml_tri_type tri_type, + ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 1, 1}, + float c = nan("")) + : type(type), ne(ne), tri_type(tri_type), c(c) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_tri(ctx, a, c, tri_type); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_MEAN struct test_mean : public test_case { const ggml_type type; @@ -5780,7 +5965,7 @@ static const ggml_type other_types[] = { }; // Test cases for evaluation: should try to cover edge cases while using 
small input sizes to keep the runtime low -static std::vector> make_test_cases_eval() { +static std::vector> make_test_cases_eval(int verbose = 0) { std::vector> test_cases; std::default_random_engine rng(0); @@ -6766,6 +6951,23 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_BF16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16})); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1})); + for (bool v : {false, true}) { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); @@ -6839,6 +7041,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_falcon(2)); #endif + // set verbose on all test cases + for (auto & tc : test_cases) { + tc->verbose = verbose; + } + return test_cases; } @@ -6917,6 +7124,23 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true)); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_BF16, { 4, 2, 2, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2025, 5, 6, 3 })); + + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_BF16)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16})); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f)); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, 
{2025, 2025, 1, 1})); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { @@ -7001,7 +7225,7 @@ static std::vector> make_test_cases_perf() { } static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter, - printer * output_printer) { + printer * output_printer, int verbose) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -7020,7 +7244,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op }; if (mode == MODE_TEST) { - auto test_cases = make_test_cases_eval(); + auto test_cases = make_test_cases_eval(verbose); filter_test_cases(test_cases, params_filter); ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL); if (backend_cpu == NULL) { @@ -7182,6 +7406,7 @@ static void usage(char ** argv) { printf(" --output specifies output format (default: console, options: console, sql, csv)\n"); printf(" --list-ops lists all available GGML operations\n"); printf(" --show-coverage shows test coverage\n"); + printf(" --verbose | -v print tensors during ops (can specify multiple times)\n"); } int main(int argc, char ** argv) { @@ -7190,6 +7415,7 @@ int main(int argc, char ** argv) { const char * op_names_filter = nullptr; const char * backend_filter = nullptr; const char * params_filter = nullptr; + int verbose = 0; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "test") == 0) { @@ -7237,6 +7463,8 @@ int main(int argc, char ** argv) { } else if (strcmp(argv[i], "--show-coverage") == 0) { show_test_coverage(); return 0; + } else if (strcmp(argv[i], "--verbose") == 0 || strcmp(argv[i], "-v") == 0) { + ++verbose; } else { usage(argv); return 1; @@ -7289,7 +7517,7 @@ int main(int argc, char ** argv) { false, "", ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024, true)); - bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get()); + bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), verbose); if (ok) { n_ok++;
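
Note on the cumsum kernels: the CUDA and Metal implementations above share the same two-phase decomposition. Each warp (CUDA) or simdgroup (Metal) first computes an inclusive prefix sum over the elements it owns and records its group total in shared/threadgroup memory; after a sync, every element adds the totals of all preceding groups. The following scalar C++ sketch is illustrative only (the fixed group size stands in for WARP_SIZE / N_SIMDWIDTH) and is not part of the patch:

#include <cstdio>
#include <vector>

// Scalar model of the two-phase inclusive scan used by the cumsum kernels.
static void cumsum_two_phase(const std::vector<float> & src, std::vector<float> & dst, int group) {
    const int n        = (int) src.size();
    const int n_groups = (n + group - 1) / group;
    std::vector<float> group_sum(n_groups, 0.0f);

    // phase 1: inclusive prefix sum inside each group; remember each group's total
    for (int g = 0; g < n_groups; ++g) {
        float acc = 0.0f;
        for (int i = g*group; i < n && i < (g + 1)*group; ++i) {
            acc += src[i];
            dst[i] = acc;
        }
        group_sum[g] = acc;
    }

    // phase 2: every element adds the totals of all preceding groups
    for (int g = 1; g < n_groups; ++g) {
        float carry = 0.0f;
        for (int j = 0; j < g; ++j) {
            carry += group_sum[j];
        }
        for (int i = g*group; i < n && i < (g + 1)*group; ++i) {
            dst[i] += carry;
        }
    }
}

int main() {
    std::vector<float> x(8, 1.0f);
    std::vector<float> y(x.size());
    cumsum_two_phase(x, y, /*group =*/ 4);
    for (float v : y) {
        printf("%g ", v);   // prints: 1 2 3 4 5 6 7 8
    }
    printf("\n");
    return 0;
}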
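
For completeness, a minimal usage sketch of the new public API (ggml_cumsum and ggml_tri as declared in ggml.h above). This is not part of the patch: it assumes the usual single-context CPU workflow with ggml_graph_compute_with_ctx from ggml-cpu.h, uses an arbitrary context size, and omits error handling:

#include "ggml.h"
#include "ggml-cpu.h"
#include <stdio.h>

int main(void) {
    // enough memory for a few small tensors plus the graph metadata (arbitrary size)
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // 4x4 matrix of ones
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
    for (int64_t i = 0; i < ggml_nelements(a); ++i) {
        ((float *) a->data)[i] = 1.0f;
    }

    // cumulative sum along dim 0: every row becomes {1, 2, 3, 4}
    struct ggml_tensor * cs = ggml_cumsum(ctx, a);

    // lower-triangular mask (diagonal included) filled with 5.0f, zeros elsewhere;
    // ggml_tri_keep(ctx, a, GGML_TRI_TYPE_LOWER_DIAG) would keep the original values instead
    struct ggml_tensor * tr = ggml_tri(ctx, a, 5.0f, GGML_TRI_TYPE_LOWER_DIAG);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cs);
    ggml_build_forward_expand(gf, tr);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("cs[0][3] = %f\n", ((float *) cs->data)[3]);     // 4.0
    printf("tr[0][1] = %f\n", ((float *) tr->data)[1]);     // 0.0 (above the diagonal)
    printf("tr[1][1] = %f\n", ((float *) tr->data)[4 + 1]); // 5.0 (on the diagonal)

    ggml_free(ctx);
    return 0;
}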