Skip to content

Commit 0ca4358

Browse files
committed
reorder register tile loop
1 parent c625544 commit 0ca4358

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

ggml/src/ggml-cuda/conv2d-implicit.cu

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
183183
input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
184184
}
185185

186+
// #pragma unroll
187+
// for (int i = 0; i < 8; ++i){
188+
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
189+
// #pragma unroll
190+
// for (int j = 0; j < 8; ++j){
191+
// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
192+
// }
193+
// }
186194
#pragma unroll
187-
for (int i = 0; i < 8; ++i){
188-
auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
195+
for (int j = 0; j < 8; ++j){
196+
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
189197
#pragma unroll
190-
for (int j = 0; j < 8; ++j){
191-
output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
198+
for (int i = 0; i < 8; ++i){
199+
output_frag[j][i] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
192200
}
193201
}
194202
}
@@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
215223
for (int i = 0; i < 8; ++i){
216224
#pragma unroll
217225
for (int j = 0; j < 8; ++j){
218-
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
226+
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][j]) * input_frag[1][i];
219227
}
220228
}
221229
}
@@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
240248
#pragma unroll
241249
for (int subj = 0; subj < 4; ++subj){
242250
// output sts
243-
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
251+
smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj];
244252
}
245253
}
246254
__syncthreads();
247255

248256
#pragma unroll
249257
for (int subk = 0; subk < 16; ++subk){
250-
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
251-
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
258+
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32;
259+
if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow)
252260
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
253261
}
254262
}

0 commit comments

Comments
 (0)