@@ -85,8 +85,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
     // Warp tile
     const uint lane_id = tx % WARPSIZE;
     const uint warp_id = tx / WARPSIZE;
-    const int mma_tid_x = warp_id / (BN / WN);
-    const int mma_tid_y = warp_id % (BN / WN);
+    const int mma_tid_x = warp_id / (BN / WN);
+    const int mma_tid_y = warp_id % (BN / WN);
 
     // size of the warp subtile
     constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
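
The hunk above only touches whitespace; the underlying warp-tile arithmetic maps warp_id onto a (BM/WM) x (BN/WN) grid of warps and derives WMITER from the tile sizes. A compile-time sanity check of that arithmetic, using assumed sizes (BN=128, WN=64, WM=64, TM=8, TN=4, WNITER=2 are illustrative only, not the values from conv_shapes below):

    constexpr unsigned int BN_ex = 128, WN_ex = 64, WM_ex = 64;
    constexpr unsigned int WARPSIZE_ex = 32, TM_ex = 8, TN_ex = 4, WNITER_ex = 2;
    constexpr unsigned int warp_id_ex = 5;
    static_assert(warp_id_ex / (BN_ex / WN_ex) == 2, "warp row inside the block tile");
    static_assert(warp_id_ex % (BN_ex / WN_ex) == 1, "warp column inside the block tile");
    static_assert((WM_ex * WN_ex) / (WARPSIZE_ex * TM_ex * TN_ex * WNITER_ex) == 2, "WMITER");
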
@@ -449,7 +449,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
         const int n = (ksplit > 0) ? gemm_i / PQ : z;
         const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
         if (n < param.n && row < param.k && col < param.Oh * param.Ow){
-            const uint outOffset = ksplit > 0 ?
+            const uint outOffset = ksplit > 0 ?
                 z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
                 row * param.Oh * param.Ow + col :
                 z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
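
The ternary above selects between two flattenings of the output: with split-K active the result is addressed as a [split, N, K, Oh*Ow] tensor, otherwise as the plain [K, Oh*Ow] slice for batch z. A small host-side helper, hypothetical and only for illustration, that reproduces the split-K flattening:

    #include <cstdint>

    // Hypothetical helper mirroring the split-K branch of outOffset above:
    // offset into a [split][n][k][Oh*Ow] layout.
    static inline uint32_t out_offset_splitk(uint32_t split, uint32_t n, uint32_t k,
                                             uint32_t ohow, uint32_t N, uint32_t K,
                                             uint32_t OHOW) {
        return ((split * N + n) * K + k) * OHOW + ohow;
    }
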
@@ -626,7 +626,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
-
+
     uint32_t (&reg_)[4][8] = reinterpret_cast<uint32_t (&)[4][8]>(reg);
     unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
     unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
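
The XOR in the last line of the hunk is a shared-memory bank swizzle: bit 7 of the per-lane offset is folded into bit 3, so rows that would otherwise map to the same banks get spread out. A compile-time check of two example offsets (values are illustrative only):

    constexpr unsigned int swz_ex(unsigned int off) { return off ^ ((off & 0b10000000) >> 4); }
    static_assert(swz_ex(128) == 136, "bit 7 set: also flips bit 3 (128 ^ 8)");
    static_assert(swz_ex(64)  == 64,  "bit 7 clear: offset unchanged");
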
@@ -739,11 +739,11 @@ constexpr unsigned int MMA_N = 8;
     constexpr int BUFFER_SIZE = BM * BK + BK * BN;
 
     // declare register storage
-    // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
+    // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
     uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
     uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
     uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
-
+
     // convenience cast to half for register storage
     half (&acc_register_)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
     half (&A_register_)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
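
The comment in the hunk above is the key detail behind the reinterpret_casts: each mma operand register is one uint32_t holding two packed half values, so the [..][2] uint32_t arrays can be viewed as [..][4] half arrays. A minimal sketch of that packing, purely illustrative:

    #include <cuda_fp16.h>
    #include <cstdint>
    #include <cstring>

    // Pack two consecutive half values into one 32-bit register image,
    // the same layout the half-array views above rely on.
    static inline uint32_t pack_two_halfs(half lo, half hi) {
        half pair[2] = { lo, hi };
        uint32_t packed;
        std::memcpy(&packed, pair, sizeof(packed));
        return packed;
    }
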
@@ -827,7 +827,7 @@ constexpr unsigned int MMA_N = 8;
 
     // reuse smem
     half *smemoutput = shmem;
-    const uint lane_id = threadIdx.x % WARPSIZE;
+    const uint lane_id = threadIdx.x % WARPSIZE;
     const uint mma_row = lane_id / 4;
     const uint mma_col = lane_id % 4;
     const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
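
lane_id / 4 and lane_id % 4 match the m16n8 accumulator fragment mapping documented in the PTX ISA, where each lane owns two rows (l/4 and l/4 + 8) and a pair of adjacent columns starting at (l % 4) * 2, which is why the accumulators are moved as packed uint32_t pairs in the next hunk. A small compile-time illustration of that mapping:

    // m16n8 accumulator fragment mapping (PTX ISA): lane l holds rows l/4 and
    // l/4 + 8, and the two adjacent columns starting at (l % 4) * 2.
    constexpr unsigned int frag_row(unsigned int lane) { return lane / 4; }
    constexpr unsigned int frag_col(unsigned int lane) { return (lane % 4) * 2; }
    static_assert(frag_row(7) == 1 && frag_col(7) == 6, "lane 7 -> rows 1 and 9, cols 6..7");
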
@@ -845,7 +845,7 @@ constexpr unsigned int MMA_N = 8;
             for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
             {
                 uint32_t (&reg_)[2] = reinterpret_cast<uint32_t (&)[2]>(acc_register_[mma_m][mma_n]);
-                uint idx = output_sts_addr +
+                uint idx = output_sts_addr +
                     mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
                 idx = idx ^ ((idx & 0b1110000000) >> 4);
                 uint32_t * dst_ptr = reinterpret_cast<uint32_t *>(&smemoutput[idx]);
@@ -902,7 +902,7 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = {
 };
 
 template <typename T, unsigned int CONV_SHAPE>
-static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
+static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
 
     const uint BM = conv_shapes[0][CONV_SHAPE];
     const uint BN = conv_shapes[1][CONV_SHAPE];
@@ -920,7 +920,7 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
     int threadz = 1; // threadz number per block
     dim3 thblock(NUM_THREADS, thready, threadz);
     dim3 grid(blockx, blocky, blockz);
-
+
     conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
         WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
 }
@@ -991,6 +991,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
     conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
     GGML_UNUSED(ctx);
+    GGML_UNUSED(cc);
 }
 
 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
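
The one functional change in this hunk is the added GGML_UNUSED(cc): the f32 path never inspects the compute capability, so the macro marks the parameter as intentionally unused and keeps -Wunused-parameter quiet. A sketch of the pattern with a hypothetical macro name (ggml.h defines the real GGML_UNUSED):

    #define GGML_UNUSED_SKETCH(x) (void)(x)   // evaluate-and-discard, silences the warning

    static void example(int cc) {
        GGML_UNUSED_SKETCH(cc);   // same pattern as GGML_UNUSED(cc) above
    }
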