77
88typedef unsigned int uint;
99constexpr uint WARPSIZE = 32 ;
10+ #define CUDA_NCHW_2_NHWC_TILE_DIM 32
11+ #define CUDA_NCHW_2_NHWC_BLOCK_NM 8
12+ #define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8
1013
1114
1215// currently unused; reserved for future split-k kernels
@@ -23,6 +26,41 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri
2326 }
2427}
2528
// Batched tiled transpose of the two innermost dimensions, with element-type conversion.
//
// For every matrix imat in [0, nmat), where nmat = ne / (ne00 * ne01):
//   src is read  as [nmat][ne00][ne01]  (ne01 is the contiguous dimension)
//   dst is written as [nmat][ne01][ne00] (ne00 is the contiguous dimension)
// i.e. dst[imat][c][r] = cast<dst_T>(src[imat][r][c]).
//
// Parameters:
//   src  - input buffer holding ne elements of type src_T
//   dst  - output buffer holding ne elements of type dst_T
//   ne   - total number of elements (all matrices)
//   ne00 - number of rows of each source matrix
//   ne01 - number of columns of each source matrix
//
// Expected launch configuration (as used by the host code in this file):
//   block = (CUDA_NCHW_2_NHWC_TILE_DIM, CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1)
//   grid  = (ceil(ne01 / TILE_DIM), ceil(ne00 / TILE_DIM), ceil(nmat / CUDA_NCHW_2_NHWC_BLOCK_NM))
// Each block handles one TILE_DIM x TILE_DIM tile for up to CUDA_NCHW_2_NHWC_BLOCK_NM
// consecutive matrices. Static shared memory: TILE_DIM * TILE_DIM * sizeof(src_T).
template <typename src_T, typename dst_T>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int64_t ne, const int ne00, const int ne01){

    // The XOR swizzle below only stays inside [0, TILE_DIM) when TILE_DIM is a power of two.
    static_assert((CUDA_NCHW_2_NHWC_TILE_DIM & (CUDA_NCHW_2_NHWC_TILE_DIM - 1)) == 0,
                  "CUDA_NCHW_2_NHWC_TILE_DIM must be a power of two");
    static_assert(CUDA_NCHW_2_NHWC_TILE_DIM % CUDA_NCHW_2_NHWC_BLOCK_ROWS == 0,
                  "CUDA_NCHW_2_NHWC_TILE_DIM must be a multiple of CUDA_NCHW_2_NHWC_BLOCK_ROWS");

    // 64-bit arithmetic so that large tensors do not overflow the element offsets.
    const int64_t n    = (int64_t) ne00 * ne01;   // elements per matrix
    const int64_t nmat = ne / n;                  // number of matrices

    // Coordinates inside the source matrix (x: column, y: row).
    const int x  = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x;
    const int y  = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
    // Coordinates inside the destination matrix: the block offsets are swapped.
    const int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x;
    const int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;

    __shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM];

    for (int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i) {

        const int64_t imat = (int64_t) blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
        // imat depends only on blockIdx.z and i, so the whole block leaves together
        // and no thread is left waiting at a barrier.
        if (imat >= nmat) {
            break;
        }

        const int64_t base = imat * n;

        // Phase 1: load the tile from src. Element (row, c) is stored at column
        // c ^ row instead of a padded layout.
        for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS) {
            if (x < ne01 && y + j < ne00) {
                const int row = threadIdx.y + j;
                const int col = threadIdx.x ^ row;
                tile[row][col] = src[base + (int64_t) (y + j) * ne01 + x];
            }
        }
        // Make the whole tile visible before any thread reads it transposed.
        __syncthreads();

        // Phase 2: read the tile transposed (undoing the swizzle) and store to dst.
        for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS) {
            if (ty + j < ne01 && tx < ne00) {
                const int col = (threadIdx.y + j) ^ threadIdx.x;
                dst[base + (int64_t) (ty + j) * ne00 + tx] = ggml_cuda_cast<dst_T>(tile[threadIdx.x][col]);
            }
        }
        // The tile is reused by the next matrix: wait until every thread has finished
        // reading it, otherwise a fast warp could overwrite entries that a slower
        // warp has not consumed yet.
        __syncthreads();
    }
}
2664
2765template <typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
2866 const int WNITER, const int TM, const int TN, const int NUM_THREADS,
@@ -882,26 +920,40 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
882920 int threadz = 1 ; // threadz number per block
883921 dim3 thblock (NUM_THREADS, thready, threadz);
884922 dim3 grid (blockx, blocky, blockz);
885- if (P.c % 4 == 0 ){
886- if (P.layout == 0 )
887- conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
888- WNITER, TM, TN, NUM_THREADS, 0 , true , 0 ><<<grid, thblock, 0 , st>>> (X_D, K_D, Y_D, P);
889- else if (P.layout == 1 )
890- conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
891- WNITER, TM, TN, NUM_THREADS, 1 , false , 0 ><<<grid, thblock, 0 , st>>> (X_D, K_D, Y_D, P);
892- } else {
893- if (P.layout == 0 )
894- conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
895- WNITER, TM, TN, NUM_THREADS, 0 , false , 0 ><<<grid, thblock, 0 , st>>> (X_D, K_D, Y_D, P);
896- else if (P.layout == 1 )
897- conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
898- WNITER, TM, TN, NUM_THREADS, 1 , false , 0 ><<<grid, thblock, 0 , st>>> (X_D, K_D, Y_D, P);
899- }
923+
924+ conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
925+ WNITER, TM, TN, NUM_THREADS, 1 , false , 0 ><<<grid, thblock, 0 , st>>> (X_D, K_D, Y_D, P);
900926}
901927
902928static void conv2d_implicit_cuda_f16 (ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
903929
904- if (GGML_CUDA_CC_IS_NVIDIA (cc) && ampere_mma_available (cc) && P.layout == 0 && P.c % 8 == 0 ) {
930+ if (GGML_CUDA_CC_IS_NVIDIA (cc) && ampere_mma_available (cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 )) {
931+
932+ int id = ggml_cuda_get_device ();
933+
934+ int64_t ne = P.c * P.h * P.w * P.n ;
935+ int64_t ne00 = P.c ;
936+ int64_t ne01 = P.h * P.w ;
937+ ggml_cuda_pool_alloc<half> input_f16 (ctx.pool (id), ne);
938+
939+ dim3 dimGrid ( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1 ) / CUDA_NCHW_2_NHWC_TILE_DIM,
940+ (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1 ) / CUDA_NCHW_2_NHWC_TILE_DIM,
941+ (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1 ) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
942+ dim3 dimBlock (CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1 );
943+ NCHW2NHWC<float , half><<<dimGrid, dimBlock, 0 , st>>> (X_D, input_f16.get (), ne, ne00, ne01);
944+
945+ ne = P.c * P.r * P.s * P.k ;
946+ ne01 = P.r * P.s ;
947+ ggml_cuda_pool_alloc<half> kernel_f16 (ctx.pool (id), ne);
948+ dim3 dimGrid1 ((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1 ) / CUDA_NCHW_2_NHWC_TILE_DIM,
949+ (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1 ) / CUDA_NCHW_2_NHWC_TILE_DIM,
950+ (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1 ) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
951+ NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0 , st>>> (K_D, kernel_f16.get (), ne, ne00, ne01);
952+
953+ const half *X_H = input_f16.get ();
954+ const half *K_H = kernel_f16.get ();
955+ ggml_cuda_pool_alloc<half> Y_H (ctx.pool (id), P.k * P.Oh * P.Ow * P.n );
956+
905957 constexpr unsigned int BM_dim = 256 ;
906958 constexpr unsigned int BN_dim = 256 ;
907959 constexpr unsigned int BK_dim = 32 ;
@@ -925,19 +977,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
925977 dim3 gridDim (BlocksN, BlocksM);
926978 dim3 blockDim (ThreadsN, ThreadsM);
927979
928- int id = ggml_cuda_get_device ();
929- ggml_cuda_pool_alloc<half> x_f16 (ctx.pool (id));
930-
931- const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda (GGML_TYPE_F32);
932- GGML_ASSERT (to_fp16_cuda != nullptr );
933- size_t ne = P.c * P.h * P.w * P.n ;
934- x_f16.alloc (ne);
935- to_fp16_cuda (X_D, x_f16.get (), ne, st);
936- const half *X_H = x_f16.get ();
937- ggml_cuda_pool_alloc<half> Y_H (ctx.pool (id), P.k * P.Oh * P.Ow * P.n );
938980 conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
939981 WM_dim, WN_dim, WK_dim, NumThreads>
940- <<<gridDim , blockDim , shmem_bytes, st>>> (X_H, K_D , Y_H.get (), P);
982+ <<<gridDim , blockDim , shmem_bytes, st>>> (X_H, K_H , Y_H.get (), P);
941983 const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda (GGML_TYPE_F16);
942984 to_fp32_cuda (Y_H.get (), Y_D, P.k * P.Oh * P.Ow * P.n , st);
943985 } else {
@@ -971,36 +1013,36 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
9711013 const int PD_Y = p[3 ]; // padding_y
9721014 const int DL_X = p[4 ]; // dilation_x
9731015 const int DL_Y = p[5 ]; // dilation_y
974- const int LT = p[6 ]; // layout
1016+ // const int LT = p[6]; // layout
9751017
976- GGML_ASSERT (LT == 0 || LT == 1 );
1018+ // GGML_ASSERT(LT == 0 || LT == 1);
9771019
9781020 // same number of input channels
979- GGML_ASSERT (LT == 0 ? input->ne [0 ] == kernel->ne [0 ] : input->ne [2 ] == kernel->ne [2 ]);
1021+ // GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]);
9801022 // No cwhn
981- GGML_ASSERT (p[7 ] == false );
1023+ GGML_ASSERT (p[6 ] == false );
9821024
983- const int IW = input->ne [LT == 0 ? 1 : 0 ]; // input_w
984- const int IH = input->ne [LT == 0 ? 2 : 1 ]; // input_h
1025+ const int IW = input->ne [0 ]; // input_w
1026+ const int IH = input->ne [1 ]; // input_h
9851027 const int OW = dst->ne [0 ]; // output_w
9861028 const int OH = dst->ne [1 ]; // output_h
987- const int KW = kernel->ne [LT == 0 ? 1 : 0 ]; // kernel_w
988- const int KH = kernel->ne [LT == 0 ? 2 : 1 ]; // kernel_h
989- const int IC = input->ne [LT == 0 ? 0 : 2 ]; // input_channels
1029+ const int KW = kernel->ne [0 ]; // kernel_w
1030+ const int KH = kernel->ne [1 ]; // kernel_h
1031+ const int IC = input->ne [2 ]; // input_channels
9901032
9911033 const int OC = kernel->ne [3 ]; // ouptut_chanles
9921034 const int B = input->ne [3 ]; // n_batches
993-
1035+
9941036 const int64_t total = B * OC * OH * OW;
995-
1037+
9961038 param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW };
9971039 params.SC_fastdiv = init_fastdiv_values (KW*IC);
9981040 params.OW_fastdiv = init_fastdiv_values (OW);
9991041 params.OHOW_fastdiv = init_fastdiv_values (OW*OH);
10001042 params.C_fastdiv = init_fastdiv_values (IC);
10011043 params.RS_fastdiv = init_fastdiv_values (KW*KH);
10021044 params.S_fastdiv = init_fastdiv_values (KW);
1003- params.layout = LT;
1045+ // params.layout = LT;
10041046
10051047 if (kernel->type == GGML_TYPE_F16) {
10061048 conv2d_implicit_cuda_f16 (ctx, X_D, (half *) K_D, Y_D, cc, params, st);
0 commit comments