Skip to content

Commit 475f987

Browse files
committed
WIP: fixed another bug
1 parent 396f558 commit 475f987

File tree

2 files changed

+54
-22
lines changed

2 files changed

+54
-22
lines changed

ggml/src/ggml-cuda/conv2d-implicit.cu

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,8 @@ __device__ __forceinline__ void ldmatrix_b(
931931

932932

933933
asm volatile (
934-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
934+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
935+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
935936
"{%0, %1, %2, %3}, [%4];"
936937
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
937938
// : "r"(src_addr ^ 0b1000000)
@@ -941,14 +942,16 @@ __device__ __forceinline__ void ldmatrix_b(
941942
src_addr ^= 0b10000;
942943

943944
asm volatile (
944-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
945+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
946+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
945947
"{%0, %1, %2, %3}, [%4];"
946948
: "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
947949
: "r"(src_addr)
948950
);
949951

950952
asm volatile (
951-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
953+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
954+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
952955
"{%0, %1, %2, %3}, [%4];"
953956
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
954957
// : "r"(src_addr ^ 0b1000000)
@@ -959,14 +962,16 @@ __device__ __forceinline__ void ldmatrix_b(
959962
src_addr ^= 0b110000;
960963

961964
asm volatile (
962-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
965+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
966+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
963967
"{%0, %1, %2, %3}, [%4];"
964968
: "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
965969
: "r"(src_addr)
966970
);
967971

968972
asm volatile (
969-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
973+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
974+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
970975
"{%0, %1, %2, %3}, [%4];"
971976
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
972977
// : "r"(src_addr ^ 0b1000000)
@@ -976,14 +981,16 @@ __device__ __forceinline__ void ldmatrix_b(
976981
src_addr ^= 0b10000;
977982

978983
asm volatile (
979-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
984+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
985+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
980986
"{%0, %1, %2, %3}, [%4];"
981987
: "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
982988
: "r"(src_addr)
983989
);
984990

985991
asm volatile (
986-
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
992+
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
993+
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
987994
"{%0, %1, %2, %3}, [%4];"
988995
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
989996
// : "r"(src_addr ^ 0b1000000)
@@ -1043,6 +1050,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
10431050
// declare register storage
10441051
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
10451052
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
1053+
// float acc_register_[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
10461054
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
10471055
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
10481056

@@ -1131,16 +1139,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
11311139
"r"(B_register[mma_k][mma_n])
11321140
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
11331141
);
1142+
// asm volatile (
1143+
// "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
1144+
// "{%0, %1, %2, %3},"
1145+
// "{%4, %5},"
1146+
// "{%6},"
1147+
// "{%7, %8, %9, %10};\n"
1148+
// : "=f"(acc_register_[mma_m][mma_n][0]), "=f"(acc_register_[mma_m][mma_n][1]),
1149+
// "=f"(acc_register_[mma_m][mma_n][2]), "=f"(acc_register_[mma_m][mma_n][3])
1150+
// : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
1151+
// "r"(B_register[mma_k][mma_n]),
1152+
// "f"(acc_register_[mma_m][mma_n][0]), "f"(acc_register_[mma_m][mma_n][1]),
1153+
// "f"(acc_register_[mma_m][mma_n][2]), "f"(acc_register_[mma_m][mma_n][3])
1154+
// );
11341155
}
11351156
}
1136-
if(threadIdx.x == 28 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
1137-
printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]),
1138-
__half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3]));
1139-
printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]),
1140-
__half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3]));
1141-
printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
1142-
__half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
1143-
}
1157+
// if(threadIdx.x == 12 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
1158+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]),
1159+
// __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3]));
1160+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[0][0][0], acc_register_[0][0][1],
1161+
// acc_register_[0][0][2], acc_register_[0][0][3]);
1162+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]),
1163+
// __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3]));
1164+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
1165+
// __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
1166+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[1][0][0], acc_register_[1][0][1],
1167+
// acc_register_[1][0][2], acc_register_[1][0][3]);
1168+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[1][mma_k][0]), __half2float(A_register_[1][mma_k][1]),
1169+
// __half2float(A_register_[1][mma_k][2]), __half2float(A_register_[1][mma_k][3]));
1170+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[3][0][0], acc_register_[3][0][1],
1171+
// acc_register_[3][0][2], acc_register_[3][0][3]);
1172+
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]),
1173+
// __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3]));
1174+
// printf(" %d, %d: %f, %f, \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
1175+
// }
11441176
// if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
11451177
// printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]));
11461178
// printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));

tests/test-conv2d-implicit.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu
5050
std::vector<float> adata(KW * KH * IC * OC);
5151
for (int i = 0; i < KW * KH * IC * OC; i++) {
5252
// adata[i] = 2.f;
53-
adata[i] = (float)(i%KW)-1.f;
53+
// adata[i] = (float)(i%KW)-1.f;
5454
// adata[i] = (rand() % 255) / 255.0;
55-
// float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
56-
// adata[i] = r;
55+
float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
56+
adata[i] = r;
5757
}
5858

5959
// Convert adata to fp16 format
@@ -63,11 +63,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu
6363
// Initialize bdata
6464
std::vector<float> bdata(IW * IH * IC * N);
6565
for (int i = 0; i < IW * IH * IC * N; i++) {
66-
bdata[i] = (float)(i%IW)/10.f;
66+
// bdata[i] = (float)(i%IW)/10.f;
6767
// bdata[i] = 1.5f;
6868
// bdata[i] = (rand() % 255) / 255.0;
69-
// float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
70-
// bdata[i] = r;
69+
float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
70+
bdata[i] = r;
7171
}
7272

7373
size_t buffer_size = 0;
@@ -452,7 +452,7 @@ int main(void)
452452
float diff = fabs(im2col_data[i] - wino_data[i]);
453453
float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
454454
// if(diff > 1.e-4) {
455-
printf("(%f, %f, %f, %f, %f, %d) \n",
455+
printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
456456
im2col_data[i], conv2d_data[i],
457457
wino_data[i], diff, diff1, i);
458458
// break;

0 commit comments

Comments
 (0)