@@ -39,7 +39,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
3939
4040
// NOTE(review): this region is a unified-diff rendering (commit view); lines marked
// "+" are the current version of the kernel, "-" lines are the replaced ones, and the
// leading fused digits are old/new line numbers from the diff, not code.
4141template <typename T>
42- static __global__ void cpy_flt_transpose (char * cx, char * cdst_direct, , const int ne,
// fix in the "+" line below: removes the stray duplicate comma (a syntax error) and
// const-qualifies the source pointer.
42+ static __global__ void cpy_flt_transpose (const char * cx, char * cdst_direct, const int ne,
4343 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
4344 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
4545 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
// NOTE(review): the diff skips original lines 46-57 here, so the declarations of
// x, y, tx-related width, n, nmat, src and dst are not visible in this view.
@@ -58,22 +58,31 @@ static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct,, const i
5858 int tx = blockIdx .y * TILE_DIM + threadIdx .x ; // transpose block offset
5959 int ty = blockIdx .x * TILE_DIM + threadIdx .y ;
6060
61- __shared__ T tile[TILE_DIM * TILE_DIM];
// change: flat shared buffer replaced by a 2D tile so it can be indexed [row][col].
61+ // __shared__ T tile[TILE_DIM * TILE_DIM];
62+ __shared__ T tile[TILE_DIM][TILE_DIM];
6263
6364 for (int i = 0 ; i < BLOCK_NM; ++i){
6465 const unsigned int imat = blockIdx .z * BLOCK_NM + i;
// NOTE: imat depends only on blockIdx.z and i, so this branch is uniform across the
// block and the __syncthreads() inside it is safe (no divergent barrier).
6566 if (imat < nmat){
6667 for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
6768 const unsigned int idx = (y+j)*width + x;
68- if (idx < n)
69- tile[threadIdx .y +j][threadIdx .x ] = src[imat*n + idx];
69+ if (idx < n){
70+ const int row = threadIdx .y +j;
// XOR-swizzled column index: element (row, c) is stored at tile[row][c ^ row],
// a standard shared-memory layout that avoids bank conflicts on the transposed
// (column-wise) read below without padding the tile.
71+ const int col = threadIdx .x ^ row;
72+ // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
73+ tile[row][col] = src[imat*n + idx];
74+ }
7075 }
// barrier between the tile-write and tile-read phases of this iteration.
7176 __syncthreads ();
7277
7378 for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
7479 const unsigned int idx = (ty+j)*width + tx;
75- if (idx < n)
76- dst[imat*n + idx] = tile[threadIdx .x ][threadIdx .y + j];
80+ if (idx < n){
81+ // const int row = threadIdx.x;
// matching un-swizzle: logical element (threadIdx.x, threadIdx.y+j) was stored
// at column (threadIdx.y+j) ^ threadIdx.x by the write phase above.
82+ const int col = (threadIdx .y +j) ^ threadIdx .x ;
83+ // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
84+ dst[imat*n + idx] = tile[threadIdx .x ][col];
85+ }
7786 }
// NOTE(review): when BLOCK_NM > 1 there is no __syncthreads() after this read loop,
// so the next i-iteration's writes to tile may race with threads still reading the
// previous tile (write-after-read hazard) — verify with compute-sanitizer racecheck.
7887 }
7988 }
@@ -180,30 +189,33 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
180189#endif
181190}
182191
// Host-side launcher for a float-type tensor copy.
//
// Dispatches to one of two kernels:
//  - cpy_flt_transpose: tiled shared-memory transpose path, compiled in only for
//    same-type copies (f32->f32 or f16->f16) when the `transpose` template flag is set;
//  - cpy_flt: the generic element-wise copy/convert path otherwise.
//
// Parameters mirror the ggml tensor layout: ne* are element counts, nb* are byte
// strides for src (nb0x) and dst (nb1x). `cdst_indirect`/`graph_cpynode_index`
// support CUDA-graph indirect destination pointers; the index is post-incremented
// once per launch, exactly as before.
//
// The `transpose` parameter defaults to false, so all existing two-argument
// instantiations keep their previous behavior.
template <typename src_t, typename dst_t, bool transpose = false>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
    // Fast transposed copy is only implemented for same-type f16->f16 / f32->f32;
    // explicit parentheses make the (same-half OR same-float) AND transpose grouping
    // unambiguous (the previous form relied on &&-over-|| precedence).
    if constexpr (transpose &&
                  ((std::is_same_v<src_t, half>  && std::is_same_v<dst_t, half>) ||
                   (std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>))) {
        // One TILE_DIM x TILE_DIM tile per block in x/y; z covers the ne/(ne00*ne01)
        // matrices, BLOCK_NM matrices per block.
        dim3 dimGrid((ne00 + TILE_DIM - 1) / TILE_DIM,
                     (ne01 + TILE_DIM - 1) / TILE_DIM,
                     (ne / (ne00 * ne01) + BLOCK_NM - 1) / BLOCK_NM);
        dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
    } else {
        // Generic element-wise path; num_blocks is only needed here (ceil-div).
        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
    }
}
208220
209221static void ggml_cpy_f32_q8_0_cuda (
// NOTE(review): diff hunks inside ggml_cuda_cpy (the function's start and end are
// outside this view). "+" lines are the current version.
@@ -389,7 +401,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
389401 CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
390402 }
391403 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
392- ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
// NOTE(review): op_params[10] == 999 is a magic sentinel — presumably written by the
// op that builds this copy node to request the transpose kernel; this cross-module
// contract is undocumented here. Consider a named constant or a flags bit, and
// verify nothing else ever stores 999 in op_params[10] — TODO confirm with the setter.
404+ if (src1->op_params [10 ] == 999 ){
405+ ggml_cpy_flt_cuda<float , float , true > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
406+ } else {
407+ ggml_cpy_flt_cuda<float , float , false > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
408+ }
393409 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
394410 ggml_cpy_flt_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
395411 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
@@ -420,7 +436,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
420436 } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
421437 ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
422438 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
423- ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
// NOTE(review): same 999 sentinel duplicated for the f16->f16 branch; if kept, the
// condition should be factored into one named predicate shared by both branches.
439+ if (src1->op_params [10 ] == 999 ){
440+ ggml_cpy_flt_cuda<half, half, true > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
441+ } else {
442+ ggml_cpy_flt_cuda<half, half, false > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
443+ }
424444 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
425445 ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
426446 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
0 commit comments