Skip to content

Commit c36b70b

Browse files
committed
transpose copy more shapes
1 parent 3809645 commit c36b70b

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
6565
const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
6666
if(imat >= nmat)
6767
break;
68-
for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
68+
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
6969
if(x < ne01 && y + j < ne00){
7070
const int row = threadIdx.y+j;
7171
const int col = threadIdx.x ^ row;
@@ -74,7 +74,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
7474
}
7575
__syncthreads();
7676

77-
for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
77+
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
7878
if(ty + j < ne01 && tx < ne00){
7979
const int col = (threadIdx.y+j) ^ threadIdx.x;
8080
dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
@@ -305,8 +305,8 @@ static void ggml_cpy_flt_cuda(
305305
// printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12);
306306
// printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
307307
// printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13);
308-
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
309-
if(ne02 == 1) {
308+
// GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
309+
if( nb00 < nb02 && nb02 < nb03) {
310310
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
311311
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
312312
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
@@ -534,6 +534,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
534534
char * src1_ddc = (char *) src1->data;
535535

536536
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
537+
const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) &&
538+
(src0->ne[3] == 1 || (src0->nb[2] < src0->nb[3] && src0->nb[0] < src0->nb[2]));
537539

538540
if (src0->type == src1->type && contiguous_srcs) {
539541
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -546,7 +548,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
546548
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
547549
}
548550
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
549-
if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){
551+
if(can_be_transposed){
550552
// printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
551553
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
552554
} else {
@@ -590,7 +592,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
590592
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
591593
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
592594
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
593-
if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){
595+
if(can_be_transposed){
594596
// printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
595597
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
596598
} else {
@@ -609,7 +611,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
609611
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
610612
}
611613
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
612-
if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){
614+
if(can_be_transposed){
613615
// printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
614616
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
615617
} else {

tests/test-backend-ops.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6521,9 +6521,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
65216521
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
65226522
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
65236523
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
6524-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6525-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6526-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6524+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6525+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6526+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6527+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6528+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6529+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6530+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
65276531

65286532
test_cases.emplace_back(new test_cont());
65296533
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -7256,16 +7260,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
72567260
// test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
72577261

72587262

7259-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7260-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7261-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7262-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7263+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7264+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7265+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7266+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
72637267

7264-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7268+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
72657269
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7266-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7267-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7268-
// test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7270+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7271+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7272+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
72697273

72707274

72717275
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

0 commit comments

Comments
 (0)