From 5afac4df041edaf5c68b181ceaef449fb91870ec Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 28 Oct 2025 22:36:50 -0400 Subject: [PATCH 01/18] WIP --- ggml/src/ggml-cuda/cpy.cu | 66 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c5821acbdeb8a..02cb9d7a895b5 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -35,6 +35,53 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +template +static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + + const T* src = reinterpret_cast(cx); + T* dst = reinterpret_cast(cdst); + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + int width = ne01; + int height = ne00; + int x = blockIdx.x * TILE_DIM + threadIdx.x; + int y = blockIdx.y * TILE_DIM + threadIdx.y; + int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset + int ty = blockIdx.x * TILE_DIM + threadIdx.y; + + __shared__ T tile[TILE_DIM][TILE_DIM]; + + for(int i = 0; i < BLOCK_NM; ++i){ + + const unsigned int imat = blockIdx.z * BLOCK_NM + i; + if(imat >= nmat) + break; + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ + if(x < width && y + j < height){ + const unsigned int idx = (y+j)*width + x; + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + tile[row][col] = src[imat*n + idx]; + } + } + __syncthreads(); + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ + if(ty + j < width && tx < height){ + const unsigned int idx = (ty+j)*height + tx; + const int col = (threadIdx.y+j) ^ threadIdx.x; + dst[imat*n + idx] = tile[threadIdx.x][col]; + } + } + } +} + static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -136,15 +183,24 @@ cudaStream_t stream) { (cx, cdst, ne); } -template +template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + if constexpr ((std::is_same_v && std::is_same_v || + std::is_same_v && std::is_same_v) + && transpose){ + dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM, + (ne00 + TILE_DIM - 1) / TILE_DIM, + (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM ); + dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); + cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + } else{ // other + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_flt><<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } } static void ggml_cpy_f32_q8_0_cuda( From 30d4607117d27f680b436018b9e256a285fae5b3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 08:41:44 -0400 Subject: [PATCH 02/18] added a cpy kernel specific to transposed tensor which uses smem to avoid uncoalesced access; test cases also added shwoing improved memory bandwidth --- ggml/src/ggml-cuda/cpy.cu | 63 ++++++++++++++++++++------------------ ggml/src/ggml-cuda/cpy.cuh | 3 ++ tests/test-backend-ops.cpp | 24 +++++++++++++-- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 02cb9d7a895b5..14cfb26a3b050 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -36,47 +36,42 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { - - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + const int nb12, const int nb13) { const T* src = reinterpret_cast(cx); T* dst = reinterpret_cast(cdst); const int64_t nmat = ne / (ne00 * ne01); const int64_t n = ne00 * ne01; - int width = ne01; - int height = ne00; - int x = blockIdx.x * TILE_DIM + threadIdx.x; - int y = blockIdx.y * TILE_DIM + threadIdx.y; - int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset - int ty = blockIdx.x * TILE_DIM + threadIdx.y; - __shared__ T tile[TILE_DIM][TILE_DIM]; + int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x; + int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y; + int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset + int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y; + + __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; - for(int i = 0; i < BLOCK_NM; ++i){ + for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){ - const unsigned int imat = blockIdx.z * BLOCK_NM + i; + const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; if(imat >= nmat) break; - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - if(x < width && y + j < height){ - const unsigned int idx = (y+j)*width + x; + for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ + if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; const int col = threadIdx.x ^ row; - tile[row][col] = src[imat*n + idx]; + tile[row][col] = src[imat*n + (y+j)*ne01 + x]; } } __syncthreads(); - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - if(ty + j < width && tx < height){ - const unsigned int idx = (ty+j)*height + tx; + for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ + if(ty + j < ne01 && tx < ne00){ const int col = (threadIdx.y+j) ^ threadIdx.x; - dst[imat*n + idx] = tile[threadIdx.x][col]; + dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; } } } @@ -188,14 +183,16 @@ static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + if constexpr ((std::is_same_v && std::is_same_v || std::is_same_v && std::is_same_v) - && transpose){ - dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM, - (ne00 + TILE_DIM - 1) / TILE_DIM, - (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM ); - dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); - cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + && transpose){ //transpose + + dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM ); + dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_BLOCK_ROWS, 1); + cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else{ // other const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> @@ -378,7 +375,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if(src0->op == GGML_OP_TRANSPOSE){ + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); @@ -417,7 +418,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if(src0->op == GGML_OP_TRANSPOSE){ + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index a7a87d8fcfb7e..f4f87f8de773a 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -1,6 +1,9 @@ #include "common.cuh" #define CUDA_CPY_BLOCK_SIZE 64 +#define CUDA_CPY_TILE_DIM 32 +#define CUDA_CPY_BLOCK_ROWS 8 +#define CUDA_CPY_BLOCK_NM 8 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aee1730137900..29fb07bae098c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2509,6 +2509,7 @@ struct test_cpy : public test_case { const std::array permute_dst; bool _src_use_permute; bool _dst_use_permute; + bool _src_transpose; std::string vars() override { return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst); @@ -2549,10 +2550,12 @@ struct test_cpy : public test_case { test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, std::array ne = {10, 10, 10, 1}, std::array permute_src = {0, 0, 0, 0}, - std::array permute_dst = {0, 0, 0, 0}) + std::array permute_dst = {0, 0, 0, 0}, + bool transpose_src = false) : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst), _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0), - _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {} + _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0), + _src_transpose(transpose_src){} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -2564,6 +2567,11 @@ struct test_cpy : public test_case { ggml_set_name(src, "src_permuted"); } + if (_src_transpose) { + src = ggml_transpose(ctx, src); + ggml_set_name(src, "src_transposed"); + } + ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne); ggml_set_name(dst, "dst"); @@ -6513,6 +6521,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); @@ -7244,6 +7254,16 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1})); test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); + + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); From d3bdcf84ff19572d8c03ccca3c637a6ce697927e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 09:27:08 -0400 Subject: [PATCH 03/18] added BF16 support --- ggml/src/ggml-cuda/cpy.cu | 11 ++++++----- tests/test-backend-ops.cpp | 3 +++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 14cfb26a3b050..b335acdb3d2c9 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -184,10 +184,7 @@ static void ggml_cpy_flt_cuda( const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { - if constexpr ((std::is_same_v && std::is_same_v || - std::is_same_v && std::is_same_v) - && transpose){ //transpose - + if (transpose){ //transpose dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, (ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM ); @@ -436,7 +433,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if(src0->op == GGML_OP_TRANSPOSE){ + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 29fb07bae098c..5d6f611de43bc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6523,6 +6523,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); @@ -7258,10 +7259,12 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); From 18818a2898bf804f3e1fafbd5aca3e652562ac96 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 16:57:03 -0400 Subject: [PATCH 04/18] more strict check to make sure src0 is a transpose --- ggml/src/ggml-cuda/cpy.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index b335acdb3d2c9..48070a47bf7b2 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -372,7 +372,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - if(src0->op == GGML_OP_TRANSPOSE){ + if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0)){ ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -415,7 +415,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - if(src0->op == GGML_OP_TRANSPOSE){ + if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0)){ ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -433,7 +433,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - if(src0->op == GGML_OP_TRANSPOSE){ + if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0)){ ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); From 35daa02a03e80d1ecf0a2c648b7b67e5c5b76ef6 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 30 Oct 2025 15:20:44 -0400 Subject: [PATCH 05/18] reformulated to handle more complicated transpose cases --- ggml/src/ggml-cuda/cpy.cu | 236 +++++++++++++++++++++++++++++++------ ggml/src/ggml-cuda/cpy.cuh | 3 - 2 files changed, 200 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index b335acdb3d2c9..0683aff4e5adb 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -7,6 +7,8 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); +const int CUDA_CPY_TILE_DIM = 16; + template static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -35,43 +37,153 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } -template -static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, +// template +// static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, +// const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, +// const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, +// const int nb12, const int nb13) { + +// const T* src = reinterpret_cast(cx); +// T* dst = reinterpret_cast(cdst); + +// const int64_t nmat = ne / (ne00 * ne01); +// const int64_t n = ne00 * ne01; + +// int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x; +// int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y; +// int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset +// int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y; + +// __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; + +// for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){ + +// const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; +// if(imat >= nmat) +// break; +// for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ +// if(x < ne01 && y + j < ne00){ +// const int row = threadIdx.y+j; +// const int col = threadIdx.x ^ row; +// tile[row][col] = src[imat*n + (y+j)*ne01 + x]; +// } +// } +// __syncthreads(); + +// for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ +// if(ty + j < ne01 && tx < ne00){ +// const int col = (threadIdx.y+j) ^ threadIdx.x; +// dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; +// } +// } +// } +// } + + +template +static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13) { const T* src = reinterpret_cast(cx); T* dst = reinterpret_cast(cdst); - - const int64_t nmat = ne / (ne00 * ne01); - const int64_t n = ne00 * ne01; + // nidx[0] inner most + // nidx[1] middle + // nidx[2] outer most + // const int64_t nmat = ne / (ne00 * ne01); + // const int64_t n = ne00 * ne01; + // const int64_t ne00 = ne0[nidx[0]]; + // const int64_t ne01 = ne0[nidx[1]]; + // const int64_t ne02 = ne0[nidx[2]]; + const int64_t n0 = ne00 * ne01; + // const int64_t ne10 = ne1[0]; + // const int64_t ne11 = ne1[1]; + // const int64_t ne12 = ne1[2]; + const int64_t n1 = ne10 * ne11; int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x; int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y; - int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset - int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y; - - __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; - - for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){ - - const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; - if(imat >= nmat) - break; - for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ - if(x < ne01 && y + j < ne00){ - const int row = threadIdx.y+j; - const int col = threadIdx.x ^ row; - tile[row][col] = src[imat*n + (y+j)*ne01 + x]; + int z = blockIdx.z * CUDA_CPY_TILE_DIM; + // int tx = blockIdx.x * CUDA_CPY_TILE_DIM[ntidx[0]] + threadIdx.x; // transpose block offset + // int ty = blockIdx.y * CUDA_CPY_TILE_DIM[ntidx[1]] + threadIdx.y; + // int tz = blockIdx.z * CUDA_CPY_TILE_DIM[ntidx[2]]; + + __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; + + for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ + // for (int j = 0; j < CUDA_CPY_TILE_DIM[1]; ++j){ + if(x < ne00 && y < ne01 && z + k < ne02){ + // const int row = threadIdx.y+j; + // const int col = threadIdx.x ^ row; + const int row = threadIdx.y; + const int col = threadIdx.x; + tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x]; + } + // } + } + __syncthreads(); + + if(zero_at == 2){ + int tx = blockIdx.z * CUDA_CPY_TILE_DIM; + if(one_at == 0){ + int ty = blockIdx.x * CUDA_CPY_TILE_DIM; + int tz = blockIdx.y * CUDA_CPY_TILE_DIM; + for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ + // const int row = threadIdx.y; + // const int col = threadIdx.x; + // const int col = (threadIdx.y+j) ^ threadIdx.x; + if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ + dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y]; + } + } + } else{ // one at 1 + int tz = blockIdx.x * CUDA_CPY_TILE_DIM; + int ty = blockIdx.y * CUDA_CPY_TILE_DIM; + for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ + // const int row = threadIdx.y; + // const int col = threadIdx.x; + // const int col = (threadIdx.y+j) ^ threadIdx.x; + if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ + dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k]; + } } } - __syncthreads(); - - for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ - if(ty + j < ne01 && tx < ne00){ - const int col = (threadIdx.y+j) ^ threadIdx.x; - dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; + } else if(zero_at == 1){ + int tx = blockIdx.y * CUDA_CPY_TILE_DIM; + if(one_at == 0){ + int ty = blockIdx.x * CUDA_CPY_TILE_DIM; + int tz = blockIdx.z * CUDA_CPY_TILE_DIM; + for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ + // const int row = threadIdx.y; + // const int col = threadIdx.x; + // const int col = (threadIdx.y+j) ^ threadIdx.x; + if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ + dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y]; + } + } + } else { // one at 2 + int ty = blockIdx.z * CUDA_CPY_TILE_DIM; + int tz = blockIdx.x * CUDA_CPY_TILE_DIM; + for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ + // const int row = threadIdx.y; + // const int col = threadIdx.x; + // const int col = (threadIdx.y+j) ^ threadIdx.x; + if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ + dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k]; + } + } + } + } else{ // zero_at_0: means only possible is one_at_2 and two_at_1; otherwise, all contiguous + int tx = blockIdx.x * CUDA_CPY_TILE_DIM; + int ty = blockIdx.z * CUDA_CPY_TILE_DIM; + int tz = blockIdx.y * CUDA_CPY_TILE_DIM; + for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ + // const int row = threadIdx.y; + // const int col = threadIdx.x; + // const int col = (threadIdx.y+j) ^ threadIdx.x; + if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ + dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x]; } } } @@ -178,18 +290,67 @@ cudaStream_t stream) { (cx, cdst, ne); } -template +template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { - if (transpose){ //transpose - dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM ); - dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_BLOCK_ROWS, 1); - cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + if (coalesced){ //transpose + // printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02); + // printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12); + // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); + // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13); + GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed + std::vector> v; + v.emplace_back(std::make_tuple(nb00, ne00, 0)); + v.emplace_back(std::make_tuple(nb01, ne01, 1)); + v.emplace_back(std::make_tuple(nb02, ne02, 2)); + std::sort(v.begin(), v.end(), + [](auto &a, auto &b) { + return std::get<0>(a) < std::get<0>(b); + }); + const int ne0_new = std::get<1>(v[0]); + const int ne1_new = std::get<1>(v[1]); + const int ne2_new = std::get<1>(v[2]); + int nidx[3]; + nidx[0] = std::get<2>(v[0]); + nidx[1] = std::get<2>(v[1]); + nidx[2] = std::get<2>(v[2]); + // printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]); + // printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new); + const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0); + const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0); + + dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM); + dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1); + if(zero_at == 2){ + if(one_at == 1){ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + }else{ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + } + } else if(zero_at == 1){ + if(one_at == 2){ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + }else{ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + } + } else{ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + } } else{ // other const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> @@ -372,7 +533,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - if(src0->op == GGML_OP_TRANSPOSE){ + if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){ + // printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -415,7 +577,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - if(src0->op == GGML_OP_TRANSPOSE){ + if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){ + // printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -433,7 +596,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - if(src0->op == GGML_OP_TRANSPOSE){ + if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){ + // printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index f4f87f8de773a..a7a87d8fcfb7e 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -1,9 +1,6 @@ #include "common.cuh" #define CUDA_CPY_BLOCK_SIZE 64 -#define CUDA_CPY_TILE_DIM 32 -#define CUDA_CPY_BLOCK_ROWS 8 -#define CUDA_CPY_BLOCK_NM 8 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); From d2ec251f4802353758a7f6ff17583b0e18739406 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 30 Oct 2025 16:43:10 -0400 Subject: [PATCH 06/18] bring back 2D transpose for higher performance --- ggml/src/ggml-cuda/cpy.cu | 184 ++++++++++++++++++++----------------- tests/test-backend-ops.cpp | 29 +++--- 2 files changed, 113 insertions(+), 100 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 0683aff4e5adb..16d2787555b94 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -8,6 +8,9 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); const int CUDA_CPY_TILE_DIM = 16; +const int CUDA_CPY_TILE_DIM_2D = 32; +const int CUDA_CPY_BLOCK_NM = 8; +const int CUDA_CPY_BLOCK_ROWS = 8; template static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, @@ -37,47 +40,47 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } -// template -// static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, -// const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, -// const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, -// const int nb12, const int nb13) { - -// const T* src = reinterpret_cast(cx); -// T* dst = reinterpret_cast(cdst); - -// const int64_t nmat = ne / (ne00 * ne01); -// const int64_t n = ne00 * ne01; - -// int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x; -// int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y; -// int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset -// int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y; - -// __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; - -// for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){ - -// const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; -// if(imat >= nmat) -// break; -// for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ -// if(x < ne01 && y + j < ne00){ -// const int row = threadIdx.y+j; -// const int col = threadIdx.x ^ row; -// tile[row][col] = src[imat*n + (y+j)*ne01 + x]; -// } -// } -// __syncthreads(); - -// for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ -// if(ty + j < ne01 && tx < ne00){ -// const int col = (threadIdx.y+j) ^ threadIdx.x; -// dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; -// } -// } -// } -// } +template +static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { + + const T* src = reinterpret_cast(cx); + T* dst = reinterpret_cast(cdst); + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; + int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset + int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + + __shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D]; + + for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){ + + const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; + if(imat >= nmat) + break; + for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ + if(x < ne01 && y + j < ne00){ + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + tile[row][col] = src[imat*n + (y+j)*ne01 + x]; + } + } + __syncthreads(); + + for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ + if(ty + j < ne01 && tx < ne00){ + const int col = (threadIdx.y+j) ^ threadIdx.x; + dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; + } + } + } +} template @@ -302,54 +305,63 @@ static void ggml_cpy_flt_cuda( // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13); GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - std::vector> v; - v.emplace_back(std::make_tuple(nb00, ne00, 0)); - v.emplace_back(std::make_tuple(nb01, ne01, 1)); - v.emplace_back(std::make_tuple(nb02, ne02, 2)); - std::sort(v.begin(), v.end(), - [](auto &a, auto &b) { - return std::get<0>(a) < std::get<0>(b); - }); - const int ne0_new = std::get<1>(v[0]); - const int ne1_new = std::get<1>(v[1]); - const int ne2_new = std::get<1>(v[2]); - int nidx[3]; - nidx[0] = std::get<2>(v[0]); - nidx[1] = std::get<2>(v[1]); - nidx[2] = std::get<2>(v[2]); - // printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]); - // printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new); - const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0); - const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0); - - dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM); - dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1); - if(zero_at == 2){ - if(one_at == 1){ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - }else{ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - } - } else if(zero_at == 1){ - if(one_at == 2){ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - }else{ - cpy_flt_coalesced<<>>( + if(ne02 == 1) { + dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); + cpy_flt_transpose<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } else{ + std::vector> v; + v.emplace_back(std::make_tuple(nb00, ne00, 0)); + v.emplace_back(std::make_tuple(nb01, ne01, 1)); + v.emplace_back(std::make_tuple(nb02, ne02, 2)); + std::sort(v.begin(), v.end(), + [](auto &a, auto &b) { + return std::get<0>(a) < std::get<0>(b); + }); + const int ne0_new = std::get<1>(v[0]); + const int ne1_new = std::get<1>(v[1]); + const int ne2_new = std::get<1>(v[2]); + int nidx[3]; + nidx[0] = std::get<2>(v[0]); + nidx[1] = std::get<2>(v[1]); + nidx[2] = std::get<2>(v[2]); + // printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]); + // printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new); + const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0); + const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0); + + dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM); + dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1); + if(zero_at == 2){ + if(one_at == 1){ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + }else{ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + } + } else if(zero_at == 1){ + if(one_at == 2){ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + }else{ + cpy_flt_coalesced<<>>( + cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, + nb10, nb11, nb12, nb13); + } + } else{ + cpy_flt_coalesced<<>>( cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } - } else{ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); } } else{ // other const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5d6f611de43bc..90eccfe879d6c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2512,7 +2512,7 @@ struct test_cpy : public test_case { bool _src_transpose; std::string vars() override { - return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst); + return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose); } double max_nmse_err() override { @@ -7249,22 +7249,23 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); From 38096455645d76afd1874267cfe926403b39f694 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 30 Oct 2025 16:44:23 -0400 Subject: [PATCH 07/18] allow build on windows --- ggml/src/ggml-cuda/cpy.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 16d2787555b94..962527515adf5 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,3 +1,4 @@ +#include #include "cpy.cuh" #include "dequantize.cuh" #include "cpy-utils.cuh" From c36b70b115b9bc806f1b9a41f26913858a36d720 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 30 Oct 2025 19:43:30 -0400 Subject: [PATCH 08/18] tranpose copy more shapes --- ggml/src/ggml-cuda/cpy.cu | 16 +++++++++------- tests/test-backend-ops.cpp | 26 +++++++++++++++----------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 962527515adf5..333855ec60a9c 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -65,7 +65,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; if(imat >= nmat) break; - for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){ if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; const int col = threadIdx.x ^ row; @@ -74,7 +74,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int } __syncthreads(); - for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){ + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){ if(ty + j < ne01 && tx < ne00){ const int col = (threadIdx.y+j) ^ threadIdx.x; dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; @@ -305,8 +305,8 @@ static void ggml_cpy_flt_cuda( // printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12); // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13); - GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - if(ne02 == 1) { + // GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed + if( nb00 < nb02 && nb02 < nb03) { dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); @@ -534,6 +534,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && + (src0->ne[3] == 1 || (src0->nb[2] < src0->nb[3] && src0->nb[0] < src0->nb[2])); if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); @@ -546,7 +548,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){ + if(can_be_transposed){ // printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { @@ -590,7 +592,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){ + if(can_be_transposed){ // printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { @@ -609,7 +611,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){ + if(can_be_transposed){ // printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 90eccfe879d6c..61798497d4005 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6521,9 +6521,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); @@ -7256,16 +7260,16 @@ static std::vector> make_test_cases_perf() { // test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); From 90fd9920b049df7a32a5495387c4c47d5c9134f3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 30 Oct 2025 20:08:50 -0400 Subject: [PATCH 09/18] minor tweak --- ggml/src/ggml-cuda/cpy.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 333855ec60a9c..7ae198135a664 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -306,7 +306,9 @@ static void ggml_cpy_flt_cuda( // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13); // GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - if( nb00 < nb02 && nb02 < nb03) { + if( nb00 < nb02 && nb02 <= nb03 ) { + // printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02); + // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); @@ -314,6 +316,8 @@ static void ggml_cpy_flt_cuda( cpy_flt_transpose<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else{ + // printf("b %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02); + // printf("d %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); std::vector> v; v.emplace_back(std::make_tuple(nb00, ne00, 0)); v.emplace_back(std::make_tuple(nb01, ne01, 1)); @@ -535,7 +539,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && - (src0->ne[3] == 1 || (src0->nb[2] < src0->nb[3] && src0->nb[0] < src0->nb[2])); + (src0->ne[3] == 1 || (src0->nb[2] <= src0->nb[3] && src0->nb[0] < src0->nb[2])); if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); From d49232c65bde736467561a571d0fb881df41fdc3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 31 Oct 2025 10:04:31 -0400 Subject: [PATCH 10/18] final clean up --- ggml/src/ggml-cuda/cpy.cu | 37 +++++-------------------------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 7ae198135a664..7178bc697100b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -92,31 +92,17 @@ static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int const T* src = reinterpret_cast(cx); T* dst = reinterpret_cast(cdst); - // nidx[0] inner most - // nidx[1] middle - // nidx[2] outer most - // const int64_t nmat = ne / (ne00 * ne01); - // const int64_t n = ne00 * ne01; - // const int64_t ne00 = ne0[nidx[0]]; - // const int64_t ne01 = ne0[nidx[1]]; - // const int64_t ne02 = ne0[nidx[2]]; + const int64_t n0 = ne00 * ne01; - // const int64_t ne10 = ne1[0]; - // const int64_t ne11 = ne1[1]; - // const int64_t ne12 = ne1[2]; const int64_t n1 = ne10 * ne11; int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x; int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y; int z = blockIdx.z * CUDA_CPY_TILE_DIM; - // int tx = blockIdx.x * CUDA_CPY_TILE_DIM[ntidx[0]] + threadIdx.x; // transpose block offset - // int ty = blockIdx.y * CUDA_CPY_TILE_DIM[ntidx[1]] + threadIdx.y; - // int tz = blockIdx.z * CUDA_CPY_TILE_DIM[ntidx[2]]; __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - // for (int j = 0; j < CUDA_CPY_TILE_DIM[1]; ++j){ if(x < ne00 && y < ne01 && z + k < ne02){ // const int row = threadIdx.y+j; // const int col = threadIdx.x ^ row; @@ -124,7 +110,6 @@ static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int const int col = threadIdx.x; tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x]; } - // } } __syncthreads(); @@ -301,14 +286,8 @@ static void ggml_cpy_flt_cuda( const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { if (coalesced){ //transpose - // printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02); - // printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12); - // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); - // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13); // GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed if( nb00 < nb02 && nb02 <= nb03 ) { - // printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02); - // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); @@ -316,8 +295,6 @@ static void ggml_cpy_flt_cuda( cpy_flt_transpose<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else{ - // printf("b %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02); - // printf("d %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03); std::vector> v; v.emplace_back(std::make_tuple(nb00, ne00, 0)); v.emplace_back(std::make_tuple(nb01, ne01, 1)); @@ -333,15 +310,14 @@ static void ggml_cpy_flt_cuda( nidx[0] = std::get<2>(v[0]); nidx[1] = std::get<2>(v[1]); nidx[2] = std::get<2>(v[2]); - // printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]); - // printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new); const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0); const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0); - dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM); + dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, + (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM); dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1); + if(zero_at == 2){ if(one_at == 1){ cpy_flt_coalesced<<>>( @@ -553,7 +529,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { if(can_be_transposed){ - // printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -597,7 +572,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { if(can_be_transposed){ - // printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -616,7 +590,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { if(can_be_transposed){ - // printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1)); ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); From 8dbb4c7df101fa19e1cf45a4055ffb4f0f53f62f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 07:49:03 -0500 Subject: [PATCH 11/18] restore some test cases --- tests/test-backend-ops.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 61798497d4005..19c179ae86ac9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7253,12 +7253,11 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); - + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0})); From 28e5cf6edf820de67b8e0e29bc3437c47872edf3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 14:42:00 -0500 Subject: [PATCH 12/18] keep only the kernel for true tranposed case; updated with review suggestions --- ggml/src/ggml-cuda/cpy.cu | 220 ++++++++------------------------------ 1 file changed, 45 insertions(+), 175 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 7178bc697100b..68cb030f0f21f 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -8,10 +8,9 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -const int CUDA_CPY_TILE_DIM = 16; -const int CUDA_CPY_TILE_DIM_2D = 32; -const int CUDA_CPY_BLOCK_NM = 8; -const int CUDA_CPY_BLOCK_ROWS = 8; +const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks +const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available +const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, @@ -53,131 +52,41 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int const int64_t nmat = ne / (ne00 * ne01); const int64_t n = ne00 * ne01; - int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; - int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; - int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset - int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; + const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset + const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; __shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D]; - for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){ +#pragma unroll + for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; - if(imat >= nmat) + if (imat >= nmat) break; - for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){ + +#pragma unroll + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; - const int col = threadIdx.x ^ row; + const int col = threadIdx.x ^ row; //swizzling to avoid bank conflicts tile[row][col] = src[imat*n + (y+j)*ne01 + x]; } } + __syncthreads(); - for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){ - if(ty + j < ne01 && tx < ne00){ - const int col = (threadIdx.y+j) ^ threadIdx.x; +#pragma unroll + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { + if (ty + j < ne01 && tx < ne00) { + const int col = (threadIdx.y+j) ^ threadIdx.x; //swizzling to avoid bank conflicts dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; } } } } - -template -static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - - const T* src = reinterpret_cast(cx); - T* dst = reinterpret_cast(cdst); - - const int64_t n0 = ne00 * ne01; - const int64_t n1 = ne10 * ne11; - - int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x; - int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y; - int z = blockIdx.z * CUDA_CPY_TILE_DIM; - - __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM]; - - for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - if(x < ne00 && y < ne01 && z + k < ne02){ - // const int row = threadIdx.y+j; - // const int col = threadIdx.x ^ row; - const int row = threadIdx.y; - const int col = threadIdx.x; - tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x]; - } - } - __syncthreads(); - - if(zero_at == 2){ - int tx = blockIdx.z * CUDA_CPY_TILE_DIM; - if(one_at == 0){ - int ty = blockIdx.x * CUDA_CPY_TILE_DIM; - int tz = blockIdx.y * CUDA_CPY_TILE_DIM; - for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - // const int row = threadIdx.y; - // const int col = threadIdx.x; - // const int col = (threadIdx.y+j) ^ threadIdx.x; - if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ - dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y]; - } - } - } else{ // one at 1 - int tz = blockIdx.x * CUDA_CPY_TILE_DIM; - int ty = blockIdx.y * CUDA_CPY_TILE_DIM; - for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - // const int row = threadIdx.y; - // const int col = threadIdx.x; - // const int col = (threadIdx.y+j) ^ threadIdx.x; - if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ - dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k]; - } - } - } - } else if(zero_at == 1){ - int tx = blockIdx.y * CUDA_CPY_TILE_DIM; - if(one_at == 0){ - int ty = blockIdx.x * CUDA_CPY_TILE_DIM; - int tz = blockIdx.z * CUDA_CPY_TILE_DIM; - for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - // const int row = threadIdx.y; - // const int col = threadIdx.x; - // const int col = (threadIdx.y+j) ^ threadIdx.x; - if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ - dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y]; - } - } - } else { // one at 2 - int ty = blockIdx.z * CUDA_CPY_TILE_DIM; - int tz = blockIdx.x * CUDA_CPY_TILE_DIM; - for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - // const int row = threadIdx.y; - // const int col = threadIdx.x; - // const int col = (threadIdx.y+j) ^ threadIdx.x; - if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ - dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k]; - } - } - } - } else{ // zero_at_0: means only possible is one_at_2 and two_at_1; otherwise, all contiguous - int tx = blockIdx.x * CUDA_CPY_TILE_DIM; - int ty = blockIdx.z * CUDA_CPY_TILE_DIM; - int tz = blockIdx.y * CUDA_CPY_TILE_DIM; - for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){ - // const int row = threadIdx.y; - // const int col = threadIdx.x; - // const int col = (threadIdx.y+j) ^ threadIdx.x; - if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){ - dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x]; - } - } - } -} - static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -279,72 +188,34 @@ cudaStream_t stream) { (cx, cdst, ne); } -template +template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { - if (coalesced){ //transpose - // GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - if( nb00 < nb02 && nb02 <= nb03 ) { - dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); - dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_flt_transpose<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); - } else{ - std::vector> v; - v.emplace_back(std::make_tuple(nb00, ne00, 0)); - v.emplace_back(std::make_tuple(nb01, ne01, 1)); - v.emplace_back(std::make_tuple(nb02, ne02, 2)); - std::sort(v.begin(), v.end(), - [](auto &a, auto &b) { - return std::get<0>(a) < std::get<0>(b); - }); - const int ne0_new = std::get<1>(v[0]); - const int ne1_new = std::get<1>(v[1]); - const int ne2_new = std::get<1>(v[2]); - int nidx[3]; - nidx[0] = std::get<2>(v[0]); - nidx[1] = std::get<2>(v[1]); - nidx[2] = std::get<2>(v[2]); - const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0); - const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0); - - dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM, - (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM); - dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1); - - if(zero_at == 2){ - if(one_at == 1){ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - }else{ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - } - } else if(zero_at == 1){ - if(one_at == 2){ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - }else{ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - } - } else{ - cpy_flt_coalesced<<>>( - cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, nb13); - } + if (transposed) { + GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed + int ne00n, ne01n, ne02n; + if (nb00 < nb02) { + ne00n = ne00; + ne01n = ne01; + ne02n = ne02; + } else if (nb00 > nb02) { + ne00n = ne00; + ne01n = ne01*ne02; + ne02n = 1; + } else { + GGML_ASSERT(false); } - } else{ // other + + dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); + cpy_flt_transpose<<>> + (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } else { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); @@ -514,8 +385,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); - const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && - (src0->ne[3] == 1 || (src0->nb[2] <= src0->nb[3] && src0->nb[0] < src0->nb[2])); + const bool can_be_transposed = nb01 == ggml_element_size(src0) && src0->ne[3] == 1; if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); @@ -528,7 +398,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - if(can_be_transposed){ + if (can_be_transposed) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -571,7 +441,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - if(can_be_transposed){ + if (can_be_transposed) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); @@ -589,7 +459,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - if(can_be_transposed){ + if (can_be_transposed) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); From 1f8e4c0e477d41dce3af75216b4218fce5fd7ac1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 15:02:00 -0500 Subject: [PATCH 13/18] make CI happy --- ggml/src/ggml-cuda/cpy.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 68cb030f0f21f..f47119dad03e5 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -385,7 +385,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); - const bool can_be_transposed = nb01 == ggml_element_size(src0) && src0->ne[3] == 1; + const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1; if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); From e909afd9d71eea4db284119b9aecaaed68ab7319 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 4 Nov 2025 07:47:57 -0500 Subject: [PATCH 14/18] remove headers not needed --- ggml/src/ggml-cuda/cpy.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index f47119dad03e5..8162ccc397a4b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,3 @@ -#include #include "cpy.cuh" #include "dequantize.cuh" #include "cpy-utils.cuh" From 3b8100c7bff1df8005f1e81f5cc69cfa213e4826 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 4 Nov 2025 09:45:35 -0500 Subject: [PATCH 15/18] reduced bank conflicts for fp16 and bf16 --- ggml/src/ggml-cuda/cpy.cu | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 8162ccc397a4b..ca26ea6ee7107 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -56,7 +56,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; - __shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D]; + __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D]; #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -69,8 +69,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; - const int col = threadIdx.x ^ row; //swizzling to avoid bank conflicts - tile[row][col] = src[imat*n + (y+j)*ne01 + x]; + const int col = (threadIdx.x*sizeof(float)/sizeof(T)) ^ row; //swizzling to avoid bank conflicts + T *tile2 = reinterpret_cast(tile[row]); + tile2[col] = src[imat*n + (y+j)*ne01 + x]; } } @@ -79,8 +80,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int #pragma unroll for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if (ty + j < ne01 && tx < ne00) { - const int col = (threadIdx.y+j) ^ threadIdx.x; //swizzling to avoid bank conflicts - dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col]; + const int col = ((threadIdx.y+j)*sizeof(float)/sizeof(T)) ^ threadIdx.x; //swizzling to avoid bank conflicts + T *tile2 = reinterpret_cast(tile[threadIdx.x]); + dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; } } } From 51a25903658e31a29fcb5adcbd1aca28285c99fb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 4 Nov 2025 09:57:07 -0500 Subject: [PATCH 16/18] add missing const* --- ggml/src/ggml-cuda/cpy.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index ca26ea6ee7107..9ac7a097f9d1a 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -81,7 +81,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if (ty + j < ne01 && tx < ne00) { const int col = ((threadIdx.y+j)*sizeof(float)/sizeof(T)) ^ threadIdx.x; //swizzling to avoid bank conflicts - T *tile2 = reinterpret_cast(tile[threadIdx.x]); + const T *tile2 = reinterpret_cast(tile[threadIdx.x]); dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; } } From 2eb51172aba47d59e44bde402b3b6b9a5df80cbc Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 4 Nov 2025 11:24:40 -0500 Subject: [PATCH 17/18] now bank conflicts free --- ggml/src/ggml-cuda/cpy.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 9ac7a097f9d1a..23183e5952f52 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -69,7 +69,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; - const int col = (threadIdx.x*sizeof(float)/sizeof(T)) ^ row; //swizzling to avoid bank conflicts + const int col = (threadIdx.x ^ row)*sizeof(float)/sizeof(T); //swizzling to avoid bank conflicts T *tile2 = reinterpret_cast(tile[row]); tile2[col] = src[imat*n + (y+j)*ne01 + x]; } @@ -80,7 +80,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int #pragma unroll for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if (ty + j < ne01 && tx < ne00) { - const int col = ((threadIdx.y+j)*sizeof(float)/sizeof(T)) ^ threadIdx.x; //swizzling to avoid bank conflicts + const int col = ((threadIdx.y+j) ^ threadIdx.x)*sizeof(float)/sizeof(T); //swizzling to avoid bank conflicts const T *tile2 = reinterpret_cast(tile[threadIdx.x]); dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; } From bc95e58d70d3b388badd8df69392380c300649ba Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 4 Nov 2025 11:57:29 -0500 Subject: [PATCH 18/18] use padding instead of swizzling --- ggml/src/ggml-cuda/cpy.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 23183e5952f52..1dba60eb143ef 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -56,7 +56,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; - __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D]; + __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -69,7 +69,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; - const int col = (threadIdx.x ^ row)*sizeof(float)/sizeof(T); //swizzling to avoid bank conflicts + const int col = threadIdx.x * sizeof(float)/sizeof(T); T *tile2 = reinterpret_cast(tile[row]); tile2[col] = src[imat*n + (y+j)*ne01 + x]; } @@ -80,7 +80,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int #pragma unroll for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if (ty + j < ne01 && tx < ne00) { - const int col = ((threadIdx.y+j) ^ threadIdx.x)*sizeof(float)/sizeof(T); //swizzling to avoid bank conflicts + const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T); const T *tile2 = reinterpret_cast(tile[threadIdx.x]); dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; }