@@ -49,10 +49,11 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co
4949 const T* src = reinterpret_cast <const T*>(cx);
5050 T* dst = reinterpret_cast <T*>(cdst);
5151
52- const int64_t nmat = ne /(ne00 * ne01);
52+ const int64_t nmat = ne / (ne00 * ne01);
5353 const int64_t n = ne00 * ne01;
5454 // const int64_t n = ne01 * ne02;
5555 int width = ne01;
56+ int height = ne00;
5657 int x = blockIdx .x * TILE_DIM + threadIdx .x ;
5758 int y = blockIdx .y * TILE_DIM + threadIdx .y ;
5859 int tx = blockIdx .y * TILE_DIM + threadIdx .x ; // transpose block offset
@@ -62,29 +63,65 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co
6263 __shared__ T tile[TILE_DIM][TILE_DIM];
6364
6465 for (int i = 0 ; i < BLOCK_NM; ++i){
66+ __syncthreads ();
67+
6568 const unsigned int imat = blockIdx .z * BLOCK_NM + i;
66- if (imat < nmat){
67- for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
69+ if (imat >= nmat)
70+ break ;
71+ for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
72+ if (imat < nmat && x < width && y + j < height){
6873 const unsigned int idx = (y+j)*width + x;
69- if (idx < n){
70- const int row = threadIdx .y +j;
71- const int col = threadIdx .x ^ row;
72- // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
73- tile[row][col] = src[imat*n + idx];
74- }
74+ const int row = threadIdx .y +j;
75+ const int col = threadIdx .x ^ row;
76+ // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
77+ tile[row][col] = src[imat*n + idx];
7578 }
76- __syncthreads ();
77-
78- for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
79- const unsigned int idx = (ty+j)*width + tx;
80- if (idx < n){
81- // const int row = threadIdx.x;
82- const int col = (threadIdx .y +j) ^ threadIdx .x ;
83- // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
84- dst[imat*n + idx] = tile[threadIdx .x ][col];
85- }
79+ }
80+ __syncthreads ();
81+
82+
83+ // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
84+ // printf("BEGIN %d\n", i);
85+ // for(int jj = 0; jj < TILE_DIM; ++jj){
86+ // for(int ii = 0; ii < TILE_DIM; ++ii)
87+ // printf("%.f, ", tile[jj][ii]);
88+ // printf("]\n");
89+ // }
90+ // }
91+
92+ for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
93+
94+ if (imat < nmat && ty + j < width && tx < height){
95+ const unsigned int idx = (ty+j)*height + tx;
96+ // const int row = threadIdx.x;
97+ const int col = (threadIdx .y +j) ^ threadIdx .x ;
98+ // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
99+ dst[imat*n + idx] = tile[threadIdx .x ][col];
100+ // if(imat*n + idx == 4*ne00){
101+ // printf("DEBUG: (%u, %u, %u, %u, %u), j=%d, tx=%d, ty=%d, imat=%u idx=%u dst[%u]=%.2f, %f\n",
102+ // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, j, tx, ty,
103+ // imat, idx, imat*n + idx, dst[imat*n + idx], tile[threadIdx.x][threadIdx.y + j]);
104+ // }
86105 }
87106 }
107+ // }
108+ }
109+
110+ if (threadIdx .x == 0 && threadIdx .y == 0 && blockIdx .x == 0 && blockIdx .y == 0 && blockIdx .z == 0 ){
111+ // for(int j = 0; j < 32; ++j){
112+ // j = 0;
113+ for (int i = 0 ; i < 32 ; ++i)
114+ // printf("%.2f, ", src[j*48+i]);
115+ // printf("%.2f, ", src[j*48+i]);
116+ printf (" %.2f, " , __half2float (src[i]));
117+ printf (" ]\n " );
118+ // }
119+ printf (" ==============================\n " );
120+ // for(int j = 0; j < 32; ++j){
121+ for (int i = 0 ; i < 32 ; ++i)
122+ printf (" %.2f, " , __half2float (dst[i]));
123+ printf (" ]\n " );
124+ // }
88125 }
89126}
90127
@@ -195,11 +232,11 @@ static void ggml_cpy_flt_cuda(
195232 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
196233 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
197234
198- if constexpr ((std::is_same_v<src_t , half> && std::is_same_v<dst_t , half> ||
235+ if constexpr ((std::is_same_v<src_t , half> && std::is_same_v<dst_t , half> ||
199236 std::is_same_v<src_t , float > && std::is_same_v<dst_t , float >)
200237 && transpose){
201- // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
202- // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11);
238+ printf (" cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n " , ne, ne00, ne01, ne10, ne11);
239+ printf (" cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n " , nb00, nb01, nb10, nb11);
203240 // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
204241 // if (transpose) { //transpose
205242 // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
0 commit comments