11#include " conv2d.cuh"
22#include " convert.cuh"
33
4- #include < cstdint>
5-
64struct conv_params {
75 const uint IW, IH;
86 const uint OW, OH;
@@ -88,6 +86,9 @@ template <typename layout> __device__ class float_mma {
8886#pragma unroll
8987 for (uint i = 0 ; i < num_acc; i++) {
9088 const uint e = lane_id + i * WARP_SIZE;
89+ if (e >= WMMA_M * WMMA_N) {
90+ continue ;
91+ }
9192 const uint m = e / WMMA_N;
9293 const uint n = e % WMMA_N;
9394
@@ -109,6 +110,9 @@ template <typename layout> __device__ class float_mma {
109110#pragma unroll
110111 for (uint i = 0 ; i < num_acc; i++) {
111112 const uint e = lane_id + i * WARP_SIZE;
113+ if (e >= WMMA_M * WMMA_N) {
114+ continue ;
115+ }
112116 const uint m = e / WMMA_N;
113117 const uint n = e % WMMA_N;
114118
@@ -164,6 +168,9 @@ template <typename layout> class half_mma {
164168# pragma unroll
165169 for (uint l = 0 ; l < tile_acc::ne; ++l) {
166170 const uint e = tile_acc::get_i (l) * WMMA_N + tile_acc::get_j (l);
171+ if (e >= WMMA_M * WMMA_N) {
172+ continue ;
173+ }
167174 const uint m = e / WMMA_N;
168175 const uint n = e % WMMA_N;
169176
@@ -313,8 +320,8 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo
313320 const int in_y = calculate_input_coord (oh, kh, P.ST_Y , P.DL_Y , P.PD_Y );
314321 const int in_x = calculate_input_coord (ow, kw, P.ST_X , P.DL_X , P.PD_X );
315322 if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW ) {
316- const int64_t in_idx = layout::input_index (n, ic, in_y, in_x, P);
317- val = ggml_cuda_cast<T>(IN[in_idx]);
323+ const uint64_t in_idx = layout::input_index (n, ic, in_y, in_x, P);
324+ val = ggml_cuda_cast<T>(IN[in_idx]);
318325 }
319326 }
320327 B_sh[brow * BS_NOHOW + bcol] = val;
@@ -359,7 +366,7 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo
359366}
360367
361368template <typename T, template <typename > class mma >
362- static void conv2d_cuda (const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
369+ static void conv2d_cuda (const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) {
363370 GGML_ASSERT (BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N);
364371
365372 const uint NUM_BL_OC = (P.OC + BS_OC - 1 ) / BS_OC;
0 commit comments