8 changes: 8 additions & 0 deletions ggml/include/ggml.h
@@ -554,6 +554,8 @@ extern "C" {

GGML_OP_GLU,

GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX,

GGML_OP_COUNT,
};

@@ -1627,6 +1629,12 @@ extern "C" {
struct ggml_tensor * b,
float scale,
float max_bias);
// fused scale + diag_mask_inf + soft_max, applied in place
GGML_API struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a);

// rotary position embedding
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)
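For context, a minimal sketch of how the fused op might replace the usual unfused sequence in an attention graph; `k`, `q`, `kq`, and `n_embd_head` are illustrative names, not part of this diff:

// before (unfused): three graph nodes, three passes over the scores
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
kq = ggml_scale(ctx, kq, 1.0f/sqrtf((float) n_embd_head));
kq = ggml_diag_mask_inf_inplace(ctx, kq, n_past);
kq = ggml_soft_max_inplace(ctx, kq);

// after (fused): one graph node, a single pass
struct ggml_tensor * kq_fused = ggml_mul_mat(ctx, k, q);
kq_fused = ggml_scale_diag_mask_inf_softmax_inplace(ctx, 1.0f/sqrtf((float) n_embd_head), n_past, kq_fused);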
66 changes: 66 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1393,6 +1393,67 @@ UseGgmlGemm2:;
}
}

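// fused scale + causal diag mask (-INF) + row-wise softmax
// note: single-threaded (asserts ith == 0); dst may alias src0, as in the in-place op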
static void ggml_compute_forward_scale_diag_mask_inf_softmax(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
assert(params->ith == 0);
assert(ggml_is_contiguous(dst));

float * dst_ptr = (float *) dst->data;
float * src0_ptr = (float *) src0->data;

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

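// op_params layout: two 32-bit slots, [0] = scale (raw f32 bits), [1] = n_past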
const int32_t * params_ptr = (const int32_t *) dst->op_params;
const float scale = *(const float *) &params_ptr[0];
const int n_past = params_ptr[1];

for (int64_t i3 = 0; i3 < ne03; i3++) {
for (int64_t i2 = 0; i2 < ne02; i2++) {
for (int64_t i1 = 0; i1 < ne01; i1++) {
float max = -INFINITY;
float sum = 0.0f;

// Scale and apply diagonal mask
for (int64_t i0 = 0; i0 < ne00; i0++) {
const int64_t idx = i3*ne02*ne01*ne00 + i2*ne01*ne00 + i1*ne00 + i0;

float val = src0_ptr[idx] * scale;

// Apply diagonal mask
if (i0 > n_past + i1) {
val = -INFINITY;
}

dst_ptr[idx] = val;
max = MAX(max, val);
}

// Exponentiate (shifted by the row max for numerical stability) and accumulate the sum
for (int64_t i0 = 0; i0 < ne00; i0++) {
const int64_t idx = i3*ne02*ne01*ne00 + i2*ne01*ne00 + i1*ne00 + i0;

const float val = dst_ptr[idx];
const float exp_val = expf(val - max);
dst_ptr[idx] = exp_val;
sum += exp_val;
}

// Normalize
const float inv_sum = 1.0f/sum;
for (int64_t i0 = 0; i0 < ne00; i0++) {
const int64_t idx = i3*ne02*ne01*ne00 + i2*ne01*ne00 + i1*ne00 + i0;
dst_ptr[idx] *= inv_sum;
}
}
}
}
}

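For intuition, the mask condition `i0 > n_past + i1` in the kernel above produces the usual causal pattern shifted right by `n_past` columns; a worked example (x = kept, - = set to -INFINITY):

// ne00 = 4, n_past = 1:
//   row i1 = 0:  x x - -
//   row i1 = 1:  x x x -
//   row i1 = 2:  x x x x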
// ggml_compute_forward_mul_mat_id

#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
@@ -2011,6 +2072,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_custom(params, tensor);
}
break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
ggml_compute_forward_scale_diag_mask_inf_softmax(params, tensor->src[0], tensor);
}
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
ggml_compute_forward_cross_entropy_loss(params, tensor);
26 changes: 24 additions & 2 deletions ggml/src/ggml.c
@@ -1010,6 +1010,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {

"CUSTOM",

"SCALE_DIAG_MASK_INF_SOFTMAX",

"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW",
@@ -1018,7 +1020,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};

-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1113,6 +1115,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {

"custom(x)",

"scale_diag_mask_inf_softmax",

"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adamw(x)",
@@ -1121,7 +1125,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};

-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -3880,6 +3884,24 @@ struct ggml_tensor * ggml_soft_max_ext_back_inplace(
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
}

struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a) {
struct ggml_tensor * result = ggml_view_tensor(ctx, a);

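// pack the f32 scale and the i32 n_past into consecutive 32-bit op_params slots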
int32_t params[2];
memcpy(&params[0], &scale, sizeof(scale));
params[1] = n_past;
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
result->src[0] = a;

return result;
}
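
Since `result` is a view of `a`, the op always runs in place over the source buffer. A non-inplace variant is not part of this diff; if one were wanted, it would presumably follow the same pattern with `ggml_dup_tensor` in place of `ggml_view_tensor` (hypothetical sketch):

// hypothetical non-inplace variant (not part of this diff)
struct ggml_tensor * ggml_scale_diag_mask_inf_softmax(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a) {
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); // own buffer instead of a view

int32_t params[2];
memcpy(&params[0], &scale, sizeof(scale));
params[1] = n_past;
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
result->src[0] = a;

return result;
}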

// ggml_rope

static struct ggml_tensor * ggml_rope_impl(
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
@@ -208,6 +208,7 @@ if (NOT GGML_BACKEND_DL)
llama_build_and_test(test-quantize-fns.cpp)
llama_build_and_test(test-quantize-perf.cpp)
llama_build_and_test(test-rope.cpp)
llama_build_and_test(test-scale-diag-mask.cpp)
endif()

# libmtmd
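With the CTest wiring that `llama_build_and_test` provides, the new test can then be run on its own, e.g.:

ctest -R test-scale-diag-mask --output-on-failure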
65 changes: 65 additions & 0 deletions tests/test-scale-diag-mask.cpp
@@ -0,0 +1,65 @@
#include "ggml/ggml.h"
#include "common/common.h"

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <math.h>

bool test_scale_diag_mask_inf_softmax() {
struct ggml_init_params params = {
/*.mem_size   =*/ 16*1024*1024,
/*.mem_buffer =*/ NULL,
/*.no_alloc   =*/ false,
};

// initialize context
struct ggml_context * ctx = ggml_init(params);

// create test tensor (2x2 matrix)
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 2);

// fill with test values
float * data = (float *) x->data;
data[0] = 1.0f; // row 0: [1.0 2.0]
data[1] = 2.0f;
data[2] = 3.0f; // row 1: [3.0 4.0]
data[3] = 4.0f;

// apply operation
float scale = 2.0f;
int n_past = 0;
struct ggml_tensor * y = ggml_scale_diag_mask_inf_softmax_inplace(ctx, scale, n_past, x);

// compute
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, y);
ggml_graph_compute_with_ctx(ctx, gf, 1);

// check results
float * result = (float *) y->data;

// Expected values after scale=2.0, causal masking (n_past=0), and softmax.
// The kernel masks where i0 > n_past + i1, so in row 0 the second element
// becomes -INF and softmaxes to 0:
// row 0: [2, -inf] -> [1, 0]
// row 1: [6, 8]    -> [exp(6)/sum, exp(8)/sum] where sum = exp(6) + exp(8)
const float eps = 1e-5f;
bool success = true;

float sum2 = expf(6.0f) + expf(8.0f);

success &= fabsf(result[0] - 1.0f) < eps;
success &= fabsf(result[1] - 0.0f) < eps;
success &= fabsf(result[2] - expf(6.0f)/sum2) < eps;
success &= fabsf(result[3] - expf(8.0f)/sum2) < eps;

// cleanup
ggml_free(ctx);

return success;
}

int main(void) {
bool success = test_scale_diag_mask_inf_softmax();
printf("%s: %s\n", __func__, success ? "PASSED" : "FAILED");
return success ? 0 : 1;
}