Skip to content

Commit 3308cce

Browse files
committed
conv3d WIP: enabled tensor core path
1 parent 3f5c504 commit 3308cce

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

ggml/src/ggml-cuda/conv3d-implicit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,9 +1007,9 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
10071007

10081008
int id = ggml_cuda_get_device();
10091009

1010-
int64_t ne = P.c * P.h * P.w * P.n;
1010+
int64_t ne = P.c * P.d * P.h * P.w * P.n;
10111011
int64_t ne00 = P.c;
1012-
int64_t ne01 = P.h * P.w;
1012+
int64_t ne01 = P.h * P.w * P.d;
10131013
ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
10141014

10151015
dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
@@ -1018,8 +1018,8 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
10181018
dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
10191019
NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
10201020

1021-
ne = P.c * P.r * P.s * P.k;
1022-
ne01 = P.r * P.s;
1021+
ne = P.c * P.r * P.s * P.t * P.k;
1022+
ne01 = P.r * P.s * P.t;
10231023
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
10241024
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
10251025
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,

tests/test-conv3d.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,13 +323,13 @@ int main(void)
323323
// std::make_tuple(960,320,104,152,3,3),
324324
// std::make_tuple(1280,1280,26,38,3,3),
325325
std::make_tuple(320,1280,26,38,8,3,3,3),
326-
// std::make_tuple(1280,1280,26,38,8,3,3,3),
327-
// std::make_tuple(320,1280,52,76,8,3,3,3),
328-
// std::make_tuple(1280,1280,52,76,8,3,3,3),
329-
// std::make_tuple(320,1280,104,152,8,3,3,3),
330-
// std::make_tuple(1280,1280,104,152,8,3,3,3),
331-
// std::make_tuple(320,1280,208,304,4,3,3,3),
332-
// std::make_tuple(640,1280,208,304,4,3,3,3),
326+
std::make_tuple(1280,1280,26,38,8,3,3,3),
327+
std::make_tuple(320,1280,52,76,8,3,3,3),
328+
std::make_tuple(1280,1280,52,76,8,3,3,3),
329+
std::make_tuple(320,1280,104,152,8,3,3,3),
330+
std::make_tuple(1280,1280,104,152,8,3,3,3),
331+
std::make_tuple(320,1280,208,304,4,3,3,3),
332+
std::make_tuple(640,1280,208,304,4,3,3,3),
333333
// std::make_tuple(1280,1280,26,38,1,1),
334334
// std::make_tuple(256,128,768,1024,3,3),
335335
// std::make_tuple(128,3,768,1024,3,3),
@@ -367,7 +367,7 @@ int main(void)
367367

368368

369369
struct ggml_cgraph * gf_res_0 = NULL;
370-
int iterations = 0;
370+
int iterations = 20;
371371

372372
double run_time0;
373373
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,

0 commit comments

Comments
 (0)