diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 7e9c3c8c7a096..2cc96a9051970 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -554,6 +554,8 @@ extern "C" {
 
         GGML_OP_GLU,
 
+        GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX,
+
         GGML_OP_COUNT,
     };
 
@@ -1627,6 +1629,12 @@ extern "C" {
             struct ggml_tensor  * b,
             float                 scale,
             float                 max_bias);
 
+    // fused soft_max with diag mask inf
+    GGML_API struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
+            struct ggml_context * ctx,
+            float                 scale,
+            int                   n_past,
+            struct ggml_tensor  * a);
     // rotary position embedding
     // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 0d5d3a3440aaf..f48165d470ebf 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1393,6 +1393,67 @@ UseGgmlGemm2:;
     }
 }
 
+static void ggml_compute_forward_scale_diag_mask_inf_softmax(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_is_contiguous(dst));
+
+    float * dst_ptr  = (float *) dst->data;
+    float * src0_ptr = (float *) src0->data;
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int32_t * params_ptr = (const int32_t *) dst->op_params;
+    const float scale  = *(const float *) &params_ptr[0];
+    const int   n_past = params_ptr[1];
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float max = -INFINITY;
+                float sum = 0.0f;
+
+                // Scale and apply diagonal mask
+                for (int64_t i0 = 0; i0 < ne00; i0++) {
+                    const int64_t idx = i3*ne02*ne01*ne00 + i2*ne01*ne00 + i1*ne00 + i0;
+
+                    float val = src0_ptr[idx] * scale;
+
+                    // Apply diagonal mask
+                    if (i0 > n_past + i1) {
+                        val = -INFINITY;
+                    }
+
+                    dst_ptr[idx] = val;
+                    max = MAX(max, val);
+                }
+
+                // Compute softmax
+                for (int64_t i0 = 0; i0 < ne00; i0++) {
+                    const int64_t idx = i3*ne02*ne01*ne00 + i2*ne01*ne00 + i1*ne00 + i0;
+
+                    const float val = dst_ptr[idx];
+                    const float exp_val = expf(val - max);
+                    dst_ptr[idx] = exp_val;
+                    sum += exp_val;
+                }
+
+                // Normalize
+                const float inv_sum = 1.0f/sum;
+                for (int64_t i0 = 0; i0 < ne00; i0++) {
+                    const int64_t idx = i3*ne02*ne01*ne00 + i2*ne01*ne00 + i1*ne00 + i0;
+                    dst_ptr[idx] *= inv_sum;
+                }
+            }
+        }
+    }
+}
+
 // ggml_compute_forward_mul_mat_id
 
 #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
@@ -2011,6 +2072,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_custom(params, tensor);
             }
             break;
+        case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
+            {
+                ggml_compute_forward_scale_diag_mask_inf_softmax(params, tensor, tensor);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d76ea58f789e2..93d7544a0cd8d 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1018,7 +1018,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
+
+    "SCALE_DIAG_MASK_INF_SOFTMAX",
 };
 
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1121,7 +1123,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
+
+    "scale_diag_mask_inf_softmax",
 };
 
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3880,6 +3884,24 @@ struct ggml_tensor * ggml_soft_max_ext_back_inplace(
     return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
 }
 
+struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
+        struct ggml_context * ctx,
+        float                 scale,
+        int                   n_past,
+        struct ggml_tensor  * a) {
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    int32_t params[2];
+    memcpy(&params[0], &scale, sizeof(scale));
+    params[1] = n_past;
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_rope
 
 static struct ggml_tensor * ggml_rope_impl(
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 91719577564a9..11ffaea074274 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -208,6 +208,7 @@ if (NOT GGML_BACKEND_DL)
     llama_build_and_test(test-quantize-fns.cpp)
     llama_build_and_test(test-quantize-perf.cpp)
     llama_build_and_test(test-rope.cpp)
+    llama_build_and_test(test-scale-diag-mask.cpp)
 endif()
 
 # libmtmd
diff --git a/tests/test-scale-diag-mask.cpp b/tests/test-scale-diag-mask.cpp
new file mode 100644
index 0000000000000..700e453f275c9
--- /dev/null
+++ b/tests/test-scale-diag-mask.cpp
@@ -0,0 +1,65 @@
+#include "ggml.h"
+#include "ggml-cpu.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+bool test_scale_diag_mask_inf_softmax() {
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ 16*1024*1024,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+
+    // initialize context
+    struct ggml_context * ctx = ggml_init(params);
+
+    // create test tensor (2x2 matrix)
+    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 2);
+
+    // fill with test values
+    float * data = (float *) x->data;
+    data[0] = 1.0f; // row 0: [1.0 2.0]
+    data[1] = 2.0f;
+    data[2] = 3.0f; // row 1: [3.0 4.0]
+    data[3] = 4.0f;
+
+    // apply operation
+    float scale  = 2.0f;
+    int   n_past = 0;
+    struct ggml_tensor * y = ggml_scale_diag_mask_inf_softmax_inplace(ctx, scale, n_past, x);
+
+    // compute
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf, y);
+    ggml_graph_compute_with_ctx(ctx, gf, 1);
+
+    // check results
+    float * result = (float *) y->data;
+
+    // Expected values after scale=2.0, causal masking (n_past=0) and softmax:
+    // row 0: [2, -inf] (element 1 is masked because i0 > n_past + i1) -> softmax [1, 0]
+    // row 1: [6, 8]    (nothing masked)                               -> softmax [exp(6)/sum, exp(8)/sum]
+    const float eps = 1e-5f;
+    bool success = true;
+
+    float sum1 = expf(6.0f) + expf(8.0f);
+
+    success &= fabsf(result[0] - 1.0f)            < eps;
+    success &= fabsf(result[1] - 0.0f)            < eps;
+    success &= fabsf(result[2] - expf(6.0f)/sum1) < eps;
+    success &= fabsf(result[3] - expf(8.0f)/sum1) < eps;
+
+    // cleanup
+    ggml_free(ctx);
+
+    return success;
+}
+
+int main(int argc, char ** argv) {
+    bool success = test_scale_diag_mask_inf_softmax();
+    printf("%s: %s\n", __func__, success ? "PASSED" : "FAILED");
+    return success ? 0 : 1;
+}
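
Usage note (illustrative sketch, not part of the patch): assuming the scale/n_past semantics above, the fused op is intended to stand in for the existing unfused scale -> diag_mask_inf -> soft_max sequence when building a causal-attention graph; `ctx`, `kq`, `scale` and `n_past` below are placeholder names for the graph context, the attention-scores tensor and its parameters.

    // unfused graph construction with the existing ggml API
    struct ggml_tensor * a = ggml_scale_inplace(ctx, kq, scale);
    a = ggml_diag_mask_inf_inplace(ctx, a, n_past);
    a = ggml_soft_max_inplace(ctx, a);

    // fused equivalent added by this patch: a single graph node, no intermediate ops
    struct ggml_tensor * b = ggml_scale_diag_mask_inf_softmax_inplace(ctx, scale, n_past, kq);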