@@ -56,7 +56,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
5656 const int tx = blockIdx .y * CUDA_CPY_TILE_DIM_2D + threadIdx .x ; // transpose block offset
5757 const int ty = blockIdx .x * CUDA_CPY_TILE_DIM_2D + threadIdx .y ;
5858
59- __shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D];
59+ __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D];
6060
6161#pragma unroll
6262 for (int i = 0 ; i < CUDA_CPY_BLOCK_NM; ++i) {
@@ -69,8 +69,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
6969 for (int j = 0 ; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
7070 if (x < ne01 && y + j < ne00){
7171 const int row = threadIdx .y +j;
72- const int col = threadIdx .x ^ row; // swizzling to avoid bank conflicts
73- tile[row][col] = src[imat*n + (y+j)*ne01 + x];
72+ const int col = (threadIdx .x *sizeof (float )/sizeof (T)) ^ row; // swizzling to avoid bank conflicts
73+ T *tile2 = reinterpret_cast <T*>(tile[row]);
74+ tile2[col] = src[imat*n + (y+j)*ne01 + x];
7475 }
7576 }
7677
@@ -79,8 +80,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
7980#pragma unroll
8081 for (int j = 0 ; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
8182 if (ty + j < ne01 && tx < ne00) {
82- const int col = (threadIdx .y +j) ^ threadIdx .x ; // swizzling to avoid bank conflicts
83- dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx .x ][col];
83+ const int col = ((threadIdx .y +j)*sizeof (float )/sizeof (T)) ^ threadIdx .x ; // swizzling to avoid bank conflicts
84+ T *tile2 = reinterpret_cast <T*>(tile[threadIdx .x ]);
85+ dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
8486 }
8587 }
8688 }
0 commit comments