@@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
183183                input_frag[(subcrs + 1 ) & 1 ][i + 4 ] = smeminput[load_flag * 128  * 8  + input_lds_addr + (subcrs + 1 ) * 128  + i + 32 ];
184184            }
185185
186+ //  #pragma unroll
187+ //              for (int i = 0; i < 8; ++i){
188+ //                  auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
189+ //  #pragma unroll
190+ //                  for (int j = 0; j < 8; ++j){
191+ //                      output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
192+ //                  }
193+ //              }
186194#pragma  unroll
187-             for  (int  i  = 0 ; i  < 8 ; ++i ){
188-                 auto  weight_frag_i = ggml_cuda_cast<float >(weight_frag[subcrs % 2 ][i]);
195+             for  (int  j  = 0 ; j  < 8 ; ++j ){
196+                 //   auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
189197#pragma  unroll
190-                 for  (int  j  = 0 ; j  < 8 ; ++j ){
191-                     output_frag[i][j ] += weight_frag_i  * input_frag[subcrs % 2 ][j];
198+                 for  (int  i  = 0 ; i  < 8 ; ++i ){
199+                     output_frag[j][i ] += ggml_cuda_cast< float >(weight_frag[subcrs %  2 ][i])  * input_frag[subcrs % 2 ][j];
192200                }
193201            }
194202        }
@@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
215223        for  (int  i = 0 ; i < 8 ; ++i){
216224#pragma  unroll
217225            for  (int  j = 0 ; j < 8 ; ++j){
218-                 output_frag[i][j] += ggml_cuda_cast<float >(weight_frag[1 ][i ]) * input_frag[1 ][j ];
226+                 output_frag[i][j] += ggml_cuda_cast<float >(weight_frag[1 ][j ]) * input_frag[1 ][i ];
219227            }
220228        }
221229    }
@@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
240248#pragma  unroll
241249                for  (int  subj = 0 ; subj < 4 ; ++subj){
242250                    //  output sts
243-                     smemoutput[output_sts_addr + subi  * 8  * 4  + subj ] = output_frag[i * 4  + subi][j * 4  + subj];
251+                     smemoutput[output_sts_addr + subj  * 8  * 4  + subi ] = output_frag[i * 4  + subi][j * 4  + subj];
244252                }
245253            }
246254            __syncthreads ();
247255
248256#pragma  unroll
249257            for  (int  subk = 0 ; subk < 16 ; ++subk){
250-                 int  outOffset = z * param.k  * param.Oh  * param.Ow  + (m_idx + i  * 16  + subk) * param.Oh  * param.Ow  + n_idx + j  * 32 ;
251-                 if  ((m_idx + i  * 16  + subk) < param.k  && (n_idx + j  * 32 ) < param.Oh  * param.Ow )
258+                 int  outOffset = z * param.k  * param.Oh  * param.Ow  + (m_idx + j  * 16  + subk) * param.Oh  * param.Ow  + n_idx + i  * 32 ;
259+                 if  ((m_idx + j  * 16  + subk) < param.k  && (n_idx + i  * 32 ) < param.Oh  * param.Ow )
252260                    output[outOffset] = smemoutput[output_lds_addr + subk * 32 ];
253261            }
254262        }
0 commit comments