@@ -36,47 +36,42 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
3636}
3737
3838template <typename T>
39- static __global__ void cpy_flt_transpose (const char * cx, char * cdst_direct , const int ne,
39+ static __global__ void cpy_flt_transpose (const char * cx, char * cdst , const int ne,
4040 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
4141 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
42- const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
43-
44- char * cdst = (cdst_indirect != nullptr ) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
42+ const int nb12, const int nb13) {
4543
4644 const T* src = reinterpret_cast <const T*>(cx);
4745 T* dst = reinterpret_cast <T*>(cdst);
4846
4947 const int64_t nmat = ne / (ne00 * ne01);
5048 const int64_t n = ne00 * ne01;
51- int width = ne01;
52- int height = ne00;
53- int x = blockIdx .x * TILE_DIM + threadIdx .x ;
54- int y = blockIdx .y * TILE_DIM + threadIdx .y ;
55- int tx = blockIdx .y * TILE_DIM + threadIdx .x ; // transpose block offset
56- int ty = blockIdx .x * TILE_DIM + threadIdx .y ;
5749
58- __shared__ T tile[TILE_DIM][TILE_DIM];
50+ int x = blockIdx .x * CUDA_CPY_TILE_DIM + threadIdx .x ;
51+ int y = blockIdx .y * CUDA_CPY_TILE_DIM + threadIdx .y ;
52+ int tx = blockIdx .y * CUDA_CPY_TILE_DIM + threadIdx .x ; // transpose block offset
53+ int ty = blockIdx .x * CUDA_CPY_TILE_DIM + threadIdx .y ;
54+
55+ __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
5956
60- for (int i = 0 ; i < BLOCK_NM ; ++i){
57+ for (int i = 0 ; i < CUDA_CPY_BLOCK_NM ; ++i){
6158
62- const unsigned int imat = blockIdx .z * BLOCK_NM + i;
59+ const unsigned int imat = blockIdx .z * CUDA_CPY_BLOCK_NM + i;
6360 if (imat >= nmat)
6461 break ;
65- for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
66- if (x < width && y + j < height){
67- const unsigned int idx = (y+j)*width + x;
62+ for (int j = 0 ; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
63+ if (x < ne01 && y + j < ne00){
6864 const int row = threadIdx .y +j;
6965 const int col = threadIdx .x ^ row;
70- tile[row][col] = src[imat*n + idx ];
66+ tile[row][col] = src[imat*n + (y+j)*ne01 + x ];
7167 }
7268 }
7369 __syncthreads ();
7470
75- for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS){
76- if (ty + j < width && tx < height){
77- const unsigned int idx = (ty+j)*height + tx;
71+ for (int j = 0 ; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
72+ if (ty + j < ne01 && tx < ne00){
7873 const int col = (threadIdx .y +j) ^ threadIdx .x ;
79- dst[imat*n + idx ] = tile[threadIdx .x ][col];
74+ dst[imat*n + (ty+j)*ne00 + tx ] = tile[threadIdx .x ][col];
8075 }
8176 }
8277 }
@@ -188,14 +183,16 @@ static void ggml_cpy_flt_cuda(
188183 const char * cx, char * cdst, const int ne,
189184 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
190185 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
186+
191187 if constexpr ((std::is_same_v<src_t , half> && std::is_same_v<dst_t , half> ||
192188 std::is_same_v<src_t , float > && std::is_same_v<dst_t , float >)
193- && transpose){
194- dim3 dimGrid ( (ne01 + TILE_DIM - 1 ) / TILE_DIM,
195- (ne00 + TILE_DIM - 1 ) / TILE_DIM,
196- (ne/(ne00*ne01) + BLOCK_NM - 1 ) / BLOCK_NM );
197- dim3 dimBlock (TILE_DIM, BLOCK_ROWS, 1 );
198- cpy_flt_transpose<dst_t ><<<dimGrid, dimBlock, 0 , stream>>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
189+ && transpose){ // transpose
190+
191+ dim3 dimGrid ( (ne01 + CUDA_CPY_TILE_DIM - 1 ) / CUDA_CPY_TILE_DIM,
192+ (ne00 + CUDA_CPY_TILE_DIM - 1 ) / CUDA_CPY_TILE_DIM,
193+ (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1 ) / CUDA_CPY_BLOCK_NM );
194+ dim3 dimBlock (CUDA_CPY_TILE_DIM, CUDA_CPY_BLOCK_ROWS, 1 );
195+ cpy_flt_transpose<dst_t ><<<dimGrid, dimBlock, 0 , stream>>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
199196 } else { // other
200197 const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1 ) / CUDA_CPY_BLOCK_SIZE;
201198 cpy_flt<cpy_1_flt<src_t , dst_t >><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0 , stream>>>
@@ -378,7 +375,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
378375 CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
379376 }
380377 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
381- ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
378+ if (src0->op == GGML_OP_TRANSPOSE){
379+ ggml_cpy_flt_cuda<float , float , true > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
380+ } else {
381+ ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
382+ }
382383 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
383384 if (contiguous_srcs) {
384385 ggml_cpy_flt_contiguous_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -417,7 +418,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
417418 } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
418419 ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
419420 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
420- ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
421+ if (src0->op == GGML_OP_TRANSPOSE){
422+ ggml_cpy_flt_cuda<half, half, true > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
423+ } else {
424+ ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
425+ }
421426 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
422427 if (contiguous_srcs) {
423428 ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
0 commit comments