Skip to content

Commit 5b4e448

Browse files
author
bssrdf
committed
Add two Winograd convolution ops (GGML_OP_WINOGRAD_STAGE0 kernel transform, GGML_OP_WINOGRAD_STAGE1 Winograd-domain convolution)
1 parent e6643c6 commit 5b4e448

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

include/ggml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,8 @@ extern "C" {
510510
GGML_OP_TIMESTEP_EMBEDDING,
511511
GGML_OP_ARGSORT,
512512
GGML_OP_LEAKY_RELU,
513+
GGML_OP_WINOGRAD_STAGE0,
514+
GGML_OP_WINOGRAD_STAGE1,
513515

514516
GGML_OP_FLASH_ATTN_EXT,
515517
GGML_OP_FLASH_ATTN_BACK,

src/ggml.c

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2995,6 +2995,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
29952995
"TIMESTEP_EMBEDDING",
29962996
"ARGSORT",
29972997
"LEAKY_RELU",
2998+
"WINOGRAD_STAGE0",
2999+
"WINOGRAD_STAGE1",
29983000

29993001
"FLASH_ATTN_EXT",
30003002
"FLASH_ATTN_BACK",
@@ -3089,6 +3091,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
30893091
"timestep_embedding(timesteps, dim, max_period)",
30903092
"argsort(x)",
30913093
"leaky_relu(x)",
3094+
"winograd_stage0(x)",
3095+
"winograd_stage1(x)",
30923096

30933097
"flash_attn_ext(x)",
30943098
"flash_attn_back(x)",
@@ -3118,7 +3122,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
31183122
"adamw(x)",
31193123
};
31203124

3121-
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
3125+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
31223126

31233127
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
31243128

@@ -7166,6 +7170,70 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0(
71667170
return result;
71677171
}
71687172

7173+
7174+
// ggml_winograd
7175+
7176+
// a: [OC,IC, 3, 3]
7177+
// result: [OC, IC, 16]
7178+
struct ggml_tensor * ggml_winograd_stage0(
7179+
struct ggml_context * ctx,
7180+
struct ggml_tensor * a) {
7181+
bool is_node = false;
7182+
GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel should be 3x3
7183+
if (a->grad) {
7184+
is_node = true;
7185+
}
7186+
7187+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 16, a->ne[2], a->ne[3], 1);
7188+
7189+
result->op = GGML_OP_WINOGRAD_STAGE0;
7190+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7191+
result->src[0] = a;
7192+
7193+
return result;
7194+
}
7195+
7196+
// ggml_winograd
7197+
// a: [OC, IC, 4, 4]
7198+
// b: [1, IC, IH, IW]
7199+
// result: [N, OC, OH, OW]
7200+
struct ggml_tensor * ggml_winograd_stage1(
7201+
struct ggml_context * ctx,
7202+
struct ggml_tensor * a,
7203+
struct ggml_tensor * b) {
7204+
bool is_node = false;
7205+
if (a->grad) {
7206+
is_node = true;
7207+
}
7208+
7209+
int OW = b->ne[0];
7210+
int OH = b->ne[1];
7211+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[3] /* OC */, 1);
7212+
7213+
result->op = GGML_OP_WINOGRAD_STAGE1;
7214+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7215+
result->src[0] = a;
7216+
result->src[1] = b;
7217+
7218+
return result;
7219+
}
7220+
7221+
struct ggml_tensor * ggml_conv_2d_3x3(
7222+
struct ggml_context * ctx,
7223+
struct ggml_tensor * a,
7224+
struct ggml_tensor * b){
7225+
7226+
7227+
GGML_ASSERT(b->ne[3] == 1); // only works for 1 input image
7228+
7229+
struct ggml_tensor* W = ggml_winograd_stage0(ctx, a);
7230+
struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b);
7231+
7232+
return result;
7233+
7234+
}
7235+
7236+
71697237
// ggml_pool_*
71707238

71717239
static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {

0 commit comments

Comments
 (0)