Commit ad7d7ff
Scaffolding for snake activation fn
SNAC uses the snake activation function. Added scaffolding to include `GGML_OP_SNAKE` as a new op. Should this be a unary op? The SNAC decoder also uses noise blocks to enhance outputs; they are optional, so they are omitted for now until the model is integrated end to end. Next steps: write the `llm_graph_context` for SNAC.
1 parent 9906bd9 · commit ad7d7ff
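
For reference, the snake activation is usually written as snake(x) = x + (1/α)·sin²(α·x), with α a learnable per-channel parameter; SNAC's decoder applies it channel-wise (as in descript-audio-codec's Snake1d). Below is a minimal scalar sketch of that reference form for comparison; note that the kernels scaffolded in this commit compute x + sin²(α·x), without the 1/α factor, which is worth reconciling against the reference implementation before the decoder is wired up end to end.

```c
// Reference (scalar) form of the snake activation -- for comparison only,
// not part of this commit. The scaffolded ggml_vec_snake_* kernels omit the
// 1/alpha scaling.
#include <math.h>

static inline float snake_ref(float x, float alpha) {
    const float s = sinf(alpha * x);
    return x + (s * s) / (alpha + 1e-9f);  // small eps guards against alpha == 0
}
```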

4 files changed, +143 -4 lines changed

convert_hf_to_gguf.py
Lines changed: 1 addition & 2 deletions

@@ -2329,7 +2329,7 @@ def set_gguf_parameters(self):
 
 @Model.register("SNACDec")
 class SNACDecModel(Model):
-    model_arch = gguf.MODEL_ARCH.SNAC_DEC # Assumes this constant is defined in gguf
+    model_arch = gguf.MODEL_ARCH.SNAC_DEC
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]:
         del bid # unused
@@ -2357,7 +2357,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])
         self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])
         self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"])
-        self.gguf_writer.add_convnext_block_count(1)
 
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):

ggml/include/ggml.h
Lines changed: 11 additions & 0 deletions

@@ -492,6 +492,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
+        GGML_OP_SNAKE,
 
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
@@ -1062,6 +1063,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_snake(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * alpha);
+
+    GGML_API struct ggml_tensor * ggml_snake_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * alpha);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
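
A hedged usage sketch of the new API, showing how the op could be exercised from graph-building code on the CPU backend. The shapes, placeholder data, and the ggml-cpu.h include (for ggml_graph_compute_with_ctx) are assumptions for illustration, not part of this commit:

```c
// Sketch: build and run a tiny graph that applies GGML_OP_SNAKE to a 2-D tensor.
#include "ggml.h"
#include "ggml-cpu.h"  // ggml_graph_compute_with_ctx (header location depends on ggml version)

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // x: 8 samples per row, 4 channels; alpha: one value per channel
    struct ggml_tensor * x     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    struct ggml_tensor * alpha = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);

    // placeholder data; real code would load these from the model
    for (int i = 0; i < 8*4; ++i) { ((float *) x->data)[i]     = 0.1f*(float) i; }
    for (int c = 0; c < 4;   ++c) { ((float *) alpha->data)[c] = 1.0f + (float) c; }

    struct ggml_tensor * y = ggml_snake(ctx, x, alpha);  // y = x + sin^2(alpha_c * x)

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);  // kernel is single-threaded for now

    ggml_free(ctx);
    return 0;
}
```

Because ggml_snake stashes a raw pointer to alpha in op_params rather than setting result->src[1], the graph never sees alpha as a dependency; that is workable for this CPU-only scaffold but will matter once allocators or other backends are involved.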

ggml/src/ggml-cpu/ggml-cpu.c
Lines changed: 99 additions & 0 deletions

@@ -1911,6 +1911,21 @@ inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const
         y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
     }
 }
+inline static void ggml_vec_snake_f32(const int n, float * y, const float * x, const float a) {
+    for (int i = 0; i < n; ++i) {
+        float x_val = x[i];
+        float sin_val = sinf(a * x_val);
+        y[i] = x_val + sin_val * sin_val;
+    }
+}
+inline static void ggml_vec_snake_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t a) {
+    for (int i = 0; i < n; ++i) {
+        float x_val = GGML_FP16_TO_FP32(x[i]); // TODO: double check this conversion
+        float a_val = GGML_FP16_TO_FP32(a);
+        float sin_val = sinf(a_val * x_val);
+        y[i] = GGML_FP32_TO_FP16(x_val + sin_val * sin_val);
+    }
+}
 inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
 inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
     for (int i = 0; i < n; ++i) {
@@ -7817,6 +7832,86 @@ static void ggml_compute_forward_leaky_relu(
     }
 }
 
+// ggml_compute_forward_snake
+
+static void ggml_compute_forward_snake_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    // Scaffold code, 1 thread for now
+    // TODO: add multithreading
+    if (params->ith != 0) {
+        return;
+    }
+
+    struct ggml_tensor * alpha = *(struct ggml_tensor **)(dst->op_params);
+    const float * x = (const float *)src0->data;
+    const float * a = (const float *)alpha->data;
+    float * y = (float *)dst->data;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int channels = src0->ne[1];
+
+    for (int i = 0; i < n; i++) {
+        int c = i % channels;
+        ggml_vec_snake_f32(nc,
+                (float *) ((char *) y + i * dst->nb[1]),
+                (const float *) ((const char *) x + i * src0->nb[1]),
+                a[c]); // alpha value for this channel
+    }
+}
+
+static void ggml_compute_forward_snake_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    struct ggml_tensor * alpha = *(struct ggml_tensor **)(dst->op_params);
+    const ggml_fp16_t * x = (const ggml_fp16_t *)src0->data;
+    const ggml_fp16_t * a = (const ggml_fp16_t *)alpha->data;
+    ggml_fp16_t * y = (ggml_fp16_t *)dst->data;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int channels = src0->ne[1];
+
+    for (int i = 0; i < n; i++) {
+        int c = i % channels;
+        ggml_vec_snake_f16(nc,
+                (ggml_fp16_t *) ((char *) y + i * dst->nb[1]),
+                (const ggml_fp16_t *) ((const char *) x + i * src0->nb[1]),
+                a[c]);
+    }
+}
+
+static void ggml_compute_forward_snake(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_snake_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_snake_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_silu_back
 
 static void ggml_compute_forward_silu_back_f32(
@@ -14555,6 +14650,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
+        case GGML_OP_SNAKE:
+            {
+                ggml_compute_forward_snake(params, tensor);
+            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
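
One assumption baked into these kernels is how a flattened row index maps back to a channel: ggml_nrows(src0) enumerates ne[1]*ne[2]*ne[3] rows with the second dimension (treated here as the channel axis) varying fastest, and the kernel steps by i*nb[1] bytes, so c = i % channels recovers the right alpha only for a contiguous [frame, channel, batch] layout. A small standalone sketch of that mapping (sizes are illustrative):

```c
// Sketch: row index -> (batch, channel) mapping assumed by
// ggml_compute_forward_snake_* for a contiguous ne = {frame_len, channels, batch, 1} tensor.
#include <stdio.h>

int main(void) {
    const int channels = 4;
    const int batch    = 2;
    const int nrows    = channels * batch;  // what ggml_nrows() would return

    for (int i = 0; i < nrows; ++i) {
        const int c = i % channels;  // same expression as in the kernel
        const int b = i / channels;  // which batch element the row belongs to
        printf("row %2d -> batch %d, channel %d\n", i, b, c);
    }
    return 0;
}
```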

ggml/src/ggml.c
Lines changed: 32 additions & 2 deletions

@@ -967,6 +967,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
+    "SNAKE",
 
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
@@ -998,7 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1097,7 +1098,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2474,6 +2475,35 @@ struct ggml_tensor * ggml_leaky_relu(
     return result;
 }
 
+// ggml snake
+
+struct ggml_tensor * ggml_snake(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * alpha) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    // store ptr to alpha tensor
+    ggml_set_op_params(result, &alpha, sizeof(alpha));
+    result->op = GGML_OP_SNAKE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_snake_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * alpha) {
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    ggml_set_op_params(result, &alpha, sizeof(alpha));
+    result->op = GGML_OP_SNAKE;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_sigmoid
 
 struct ggml_tensor * ggml_sigmoid(
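
The constructors above pass alpha by copying a raw tensor pointer into op_params. One alternative, sketched below under the hypothetical name ggml_snake_src1, would carry alpha as src[1] so that the graph, the allocators, and non-CPU backends all see it as a real input; this is not what the commit does, just a possible follow-up:

```c
// Alternative wiring (sketch, hypothetical): alpha as a second source tensor.
// The CPU kernel would then read dst->src[1] instead of decoding op_params.
struct ggml_tensor * ggml_snake_src1(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * alpha) {
    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

    result->op     = GGML_OP_SNAKE;
    result->src[0] = a;
    result->src[1] = alpha;  // tracked as a graph dependency

    return result;
}
```

A shape check (e.g. that alpha provides one value per channel of a) would also fit naturally in this constructor.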
