@@ -52,7 +52,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co
5252 const int64_t nmat = ne /(ne00 * ne01);
5353 const int64_t n = ne00 * ne01;
5454 // const int64_t n = ne01 * ne02;
55- int width = gridDim . x * TILE_DIM ;
55+ int width = ne01 ;
5656 int x = blockIdx .x * TILE_DIM + threadIdx .x ;
5757 int y = blockIdx .y * TILE_DIM + threadIdx .y ;
5858 int tx = blockIdx .y * TILE_DIM + threadIdx .x ; // transpose block offset
@@ -194,21 +194,22 @@ static void ggml_cpy_flt_cuda(
194194 const char * cx, char * cdst, const int ne,
195195 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
196196 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
197- const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1 ) / CUDA_CPY_BLOCK_SIZE;
198- if constexpr ((std::is_same_v<src_t , half> && std::is_same_v<dst_t , half> ||
197+
198+ if constexpr ((std::is_same_v<src_t , half> && std::is_same_v<dst_t , half> ||
199199 std::is_same_v<src_t , float > && std::is_same_v<dst_t , float >)
200200 && transpose){
201201 // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
202202 // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11);
203203 // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
204204 // if (transpose) { //transpose
205205 // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
206- dim3 dimGrid ( (ne00 + TILE_DIM - 1 ) / TILE_DIM,
207- (ne01 + TILE_DIM - 1 ) / TILE_DIM,
206+ dim3 dimGrid ( (ne01 + TILE_DIM - 1 ) / TILE_DIM,
207+ (ne00 + TILE_DIM - 1 ) / TILE_DIM,
208208 (ne/(ne00*ne01) + BLOCK_NM - 1 ) / BLOCK_NM );
209209 dim3 dimBlock (TILE_DIM, BLOCK_ROWS, 1 );
210210 cpy_flt_transpose<dst_t ><<<dimGrid, dimBlock, 0 , stream>>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
211211 } else { // other
212+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1 ) / CUDA_CPY_BLOCK_SIZE;
212213 cpy_flt<cpy_1_flt<src_t , dst_t >><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0 , stream>>>
213214 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
214215 }
@@ -401,7 +402,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
401402 CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
402403 }
403404 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
404- if (src1 ->op_params [10 ] == 999 ){
405+ if (src0 ->op_params [10 ] == 999 ){
405406 ggml_cpy_flt_cuda<float , float , true > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
406407 } else {
407408 ggml_cpy_flt_cuda<float , float , false > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -436,7 +437,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
436437 } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
437438 ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
438439 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
439- if (src1 ->op_params [10 ] == 999 ){
440+ if (src0 ->op_params [10 ] == 999 ){
440441 ggml_cpy_flt_cuda<half, half, true > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
441442 } else {
442443 ggml_cpy_flt_cuda<half, half, false > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
0 commit comments