Skip to content

Commit 02a3cb1

Browse files
author
bssrdf
committed
winograd build ok
1 parent 2ccc67d commit 02a3cb1

File tree

6 files changed

+1223
-790
lines changed

6 files changed

+1223
-790
lines changed

include/ggml.h

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1698,6 +1698,21 @@ extern "C" {
1698 1698
struct ggml_tensor * a,
1699 1699
struct ggml_tensor * b,
1700 1700
int stride);
1701+
1702+
GGML_API struct ggml_tensor * ggml_winograd_stage0(
1703+
struct ggml_context * ctx,
1704+
struct ggml_tensor * a);
1705+
1706+
GGML_API struct ggml_tensor * ggml_winograd_stage1(
1707+
struct ggml_context * ctx,
1708+
struct ggml_tensor * a,
1709+
struct ggml_tensor * b);
1710+
1711+
GGML_API struct ggml_tensor * ggml_conv_2d_3x3(
1712+
struct ggml_context * ctx,
1713+
struct ggml_tensor * a,
1714+
struct ggml_tensor * b);
1715+
1701 1716

1702 1717
enum ggml_op_pool {
1703 1718
GGML_OP_POOL_MAX,

src/ggml-cuda.cu

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
10 10
#include "ggml-cuda/clamp.cuh"
11 11
#include "ggml-cuda/concat.cuh"
12 12
#include "ggml-cuda/conv-transpose-1d.cuh"
13+
#include "ggml-cuda/conv-winograd.cuh"
13 14
#include "ggml-cuda/convert.cuh"
14 15
#include "ggml-cuda/cpy.cuh"
15 16
#include "ggml-cuda/cross-entropy-loss.cuh"
@@ -2331,6 +2332,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2331 2332
case GGML_OP_CONV_TRANSPOSE_1D:
2332 2333
ggml_cuda_op_conv_transpose_1d(ctx,dst);
2333 2334
break;
2335+
case GGML_OP_WINOGRAD_STAGE0:
2336+
ggml_cuda_op_winograd_stage0(ctx, dst);
2337+
break;
2338+
case GGML_OP_WINOGRAD_STAGE1:
2339+
ggml_cuda_op_winograd_stage1(ctx, dst);
2340+
break;
2334 2341
case GGML_OP_POOL_2D:
2335 2342
ggml_cuda_op_pool2d(ctx, dst);
2336 2343
break;
@@ -2950,6 +2957,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2950 2957
}
2951 2958
return false;
2952 2959
} break;
2960+
case GGML_OP_WINOGRAD_STAGE0:
2961+
case GGML_OP_WINOGRAD_STAGE1:
2953 2962
case GGML_OP_NONE:
2954 2963
case GGML_OP_RESHAPE:
2955 2964
case GGML_OP_VIEW:

0 commit comments

Comments
 (0)