@@ -1076,7 +1076,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
10761076 // prefetch the first block tile of A,B into shared memory
10771077// half* A_block_gmem = input + (block_m * BM * A_stride);
10781078 const half* A_block_gmem = input;
1079- const half* B_block_gmem = kernel + (block_n * weightKOffset);
1079+ // const half* B_block_gmem = kernel + (block_n * weightKOffset);
1080+ const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
10801081 tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
10811082 tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
10821083
@@ -1097,7 +1098,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
10971098 {
10981099 // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
10991100 const half* A_block_gmem = input;
1100- const half* B_block_gmem = kernel + (block_n * weightKOffset);
1101+ // const half* B_block_gmem = kernel + (block_n * weightKOffset);
1102+ const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
11011103 tileMemcpyLoadA<BM, BK, NUM_THREADS, 4 >(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
11021104 tileMemcpyLoadB<BN, BK, NUM_THREADS, 4 >(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
11031105 }
@@ -1119,6 +1121,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
11191121 {
11201122 asm volatile (
11211123 " mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
1124+ // "mma.sync.aligned.m16n8k8.row.row.f16.f16.f16.f16 "
11221125 " {%0, %1}, "
11231126 " {%2, %3}, "
11241127 " {%4}, "
@@ -1130,14 +1133,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
11301133 );
11311134 }
11321135 }
1133- // if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
1134- // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[3 ][0][0]), __half2float(acc_register_[3 ][0][1]),
1135- // __half2float(acc_register_[3 ][0][2]), __half2float(acc_register_[3 ][0][3]));
1136- // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3 ][mma_k][0]), __half2float(A_register_[3 ][mma_k][1]),
1137- // __half2float(A_register_[3 ][mma_k][2]), __half2float(A_register_[3 ][mma_k][3]));
1138- // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
1139- // __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
1140- // }
1136+ if (threadIdx .x == 28 && threadIdx .y ==0 && blockIdx .x ==0 && blockIdx .y ==0 ){
1137+ printf (" %d, %d: %f, %f, %f, %f \n " , block_k, mma_k, __half2float (acc_register_[0 ][0 ][0 ]), __half2float (acc_register_[0 ][0 ][1 ]),
1138+ __half2float (acc_register_[0 ][0 ][2 ]), __half2float (acc_register_[0 ][0 ][3 ]));
1139+ printf (" %d, %d: %f, %f, %f, %f \n " , block_k, mma_k, __half2float (A_register_[0 ][mma_k][0 ]), __half2float (A_register_[0 ][mma_k][1 ]),
1140+ __half2float (A_register_[0 ][mma_k][2 ]), __half2float (A_register_[0 ][mma_k][3 ]));
1141+ printf (" %d, %d: %f, %f, %f, %f \n " , block_k, mma_k, __half2float (B_register_[mma_k][0 ][0 ]), __half2float (B_register_[mma_k][0 ][1 ]),
1142+ __half2float (B_register_[mma_k][0 ][2 ]), __half2float (B_register_[mma_k][0 ][3 ]));
1143+ }
11411144 // if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
11421145 // printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]));
11431146 // printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
0 commit comments