@@ -931,7 +931,8 @@ __device__ __forceinline__ void ldmatrix_b(
931931
932932
933933 asm volatile (
934- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
934+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
935+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
935936 " {%0, %1, %2, %3}, [%4];"
936937 : " =r" (reg_[0 ][4 ]), " =r" (reg_[0 ][5 ]), " =r" (reg_[0 ][6 ]), " =r" (reg_[0 ][7 ])
937938 // : "r"(src_addr ^ 0b1000000)
@@ -941,14 +942,16 @@ __device__ __forceinline__ void ldmatrix_b(
941942 src_addr ^= 0b10000 ;
942943
943944 asm volatile (
944- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
945+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
946+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
945947 " {%0, %1, %2, %3}, [%4];"
946948 : " =r" (reg_[1 ][0 ]), " =r" (reg_[1 ][1 ]), " =r" (reg_[1 ][2 ]), " =r" (reg_[1 ][3 ])
947949 : " r" (src_addr)
948950 );
949951
950952 asm volatile (
951- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
953+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
954+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
952955 " {%0, %1, %2, %3}, [%4];"
953956 : " =r" (reg_[1 ][4 ]), " =r" (reg_[1 ][5 ]), " =r" (reg_[1 ][6 ]), " =r" (reg_[1 ][7 ])
954957 // : "r"(src_addr ^ 0b1000000)
@@ -959,14 +962,16 @@ __device__ __forceinline__ void ldmatrix_b(
959962 src_addr ^= 0b110000 ;
960963
961964 asm volatile (
962- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
965+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
966+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
963967 " {%0, %1, %2, %3}, [%4];"
964968 : " =r" (reg_[2 ][0 ]), " =r" (reg_[2 ][1 ]), " =r" (reg_[2 ][2 ]), " =r" (reg_[2 ][3 ])
965969 : " r" (src_addr)
966970 );
967971
968972 asm volatile (
969- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
973+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
974+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
970975 " {%0, %1, %2, %3}, [%4];"
971976 : " =r" (reg_[2 ][4 ]), " =r" (reg_[2 ][5 ]), " =r" (reg_[2 ][6 ]), " =r" (reg_[2 ][7 ])
972977 // : "r"(src_addr ^ 0b1000000)
@@ -976,14 +981,16 @@ __device__ __forceinline__ void ldmatrix_b(
976981 src_addr ^= 0b10000 ;
977982
978983 asm volatile (
979- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
984+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
985+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
980986 " {%0, %1, %2, %3}, [%4];"
981987 : " =r" (reg_[3 ][0 ]), " =r" (reg_[3 ][1 ]), " =r" (reg_[3 ][2 ]), " =r" (reg_[3 ][3 ])
982988 : " r" (src_addr)
983989 );
984990
985991 asm volatile (
986- " ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
992+ // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
993+ " ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
987994 " {%0, %1, %2, %3}, [%4];"
988995 : " =r" (reg_[3 ][4 ]), " =r" (reg_[3 ][5 ]), " =r" (reg_[3 ][6 ]), " =r" (reg_[3 ][7 ])
989996 // : "r"(src_addr ^ 0b1000000)
@@ -1043,6 +1050,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
10431050 // declare register storage
10441051 // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
10451052 uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2 ];
1053+ // float acc_register_[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
10461054 uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2 ];
10471055 uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
10481056
@@ -1131,16 +1139,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
11311139 " r" (B_register[mma_k][mma_n])
11321140 " r" (acc_register[mma_m][mma_n][0 ]), " r" (acc_register[mma_m][mma_n][1 ])
11331141 );
1142+ // asm volatile (
1143+ // "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
1144+ // "{%0, %1, %2, %3},"
1145+ // "{%4, %5},"
1146+ // "{%6},"
1147+ // "{%7, %8, %9, %10};\n"
1148+ // : "=f"(acc_register_[mma_m][mma_n][0]), "=f"(acc_register_[mma_m][mma_n][1]),
1149+ // "=f"(acc_register_[mma_m][mma_n][2]), "=f"(acc_register_[mma_m][mma_n][3])
1150+ // : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
1151+ // "r"(B_register[mma_k][mma_n]),
1152+ // "f"(acc_register_[mma_m][mma_n][0]), "f"(acc_register_[mma_m][mma_n][1]),
1153+ // "f"(acc_register_[mma_m][mma_n][2]), "f"(acc_register_[mma_m][mma_n][3])
1154+ // );
11341155 }
11351156 }
1136- if (threadIdx .x == 28 && threadIdx .y ==0 && blockIdx .x ==0 && blockIdx .y ==0 ){
1137- printf (" %d, %d: %f, %f, %f, %f \n " , block_k, mma_k, __half2float (acc_register_[0 ][0 ][0 ]), __half2float (acc_register_[0 ][0 ][1 ]),
1138- __half2float (acc_register_[0 ][0 ][2 ]), __half2float (acc_register_[0 ][0 ][3 ]));
1139- printf (" %d, %d: %f, %f, %f, %f \n " , block_k, mma_k, __half2float (A_register_[0 ][mma_k][0 ]), __half2float (A_register_[0 ][mma_k][1 ]),
1140- __half2float (A_register_[0 ][mma_k][2 ]), __half2float (A_register_[0 ][mma_k][3 ]));
1141- printf (" %d, %d: %f, %f, %f, %f \n " , block_k, mma_k, __half2float (B_register_[mma_k][0 ][0 ]), __half2float (B_register_[mma_k][0 ][1 ]),
1142- __half2float (B_register_[mma_k][0 ][2 ]), __half2float (B_register_[mma_k][0 ][3 ]));
1143- }
1157+ // if(threadIdx.x == 12 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
1158+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]),
1159+ // __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3]));
1160+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[0][0][0], acc_register_[0][0][1],
1161+ // acc_register_[0][0][2], acc_register_[0][0][3]);
1162+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]),
1163+ // __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3]));
1164+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
1165+ // __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
1166+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[1][0][0], acc_register_[1][0][1],
1167+ // acc_register_[1][0][2], acc_register_[1][0][3]);
1168+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[1][mma_k][0]), __half2float(A_register_[1][mma_k][1]),
1169+ // __half2float(A_register_[1][mma_k][2]), __half2float(A_register_[1][mma_k][3]));
1170+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[3][0][0], acc_register_[3][0][1],
1171+ // acc_register_[3][0][2], acc_register_[3][0][3]);
1172+ // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]),
1173+ // __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3]));
1174+ // printf(" %d, %d: %f, %f, \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
1175+ // }
11441176 // if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
11451177 // printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]));
11461178 // printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
0 commit comments