Skip to content

Commit d2ec251

Browse files
author
bssrdf
committed
bring back 2D transpose for higher performance
1 parent 29387ce commit d2ec251

File tree

2 files changed

+113
-100
lines changed

2 files changed

+113
-100
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 98 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
99

1010
const int CUDA_CPY_TILE_DIM = 16;
11+
const int CUDA_CPY_TILE_DIM_2D = 32;
12+
const int CUDA_CPY_BLOCK_NM = 8;
13+
const int CUDA_CPY_BLOCK_ROWS = 8;
1114

1215
template <cpy_kernel_t cpy_1>
1316
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
@@ -37,47 +40,47 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
3740
cpy_1(cx + x_offset, cdst + dst_offset);
3841
}
3942

40-
// template <typename T>
41-
// static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
42-
// const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
43-
// const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
44-
// const int nb12, const int nb13) {
45-
46-
// const T* src = reinterpret_cast<const T*>(cx);
47-
// T* dst = reinterpret_cast<T*>(cdst);
48-
49-
// const int64_t nmat = ne / (ne00 * ne01);
50-
// const int64_t n = ne00 * ne01;
51-
52-
// int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
53-
// int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
54-
// int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset
55-
// int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y;
56-
57-
// __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
58-
59-
// for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){
60-
61-
// const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
62-
// if(imat >= nmat)
63-
// break;
64-
// for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
65-
// if(x < ne01 && y + j < ne00){
66-
// const int row = threadIdx.y+j;
67-
// const int col = threadIdx.x ^ row;
68-
// tile[row][col] = src[imat*n + (y+j)*ne01 + x];
69-
// }
70-
// }
71-
// __syncthreads();
72-
73-
// for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
74-
// if(ty + j < ne01 && tx < ne00){
75-
// const int col = (threadIdx.y+j) ^ threadIdx.x;
76-
// dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
77-
// }
78-
// }
79-
// }
80-
// }
43+
// Tiled 2D transpose of a batch of matrices through shared memory.
//
// Each source matrix is read as ne00 rows of ne01 contiguous elements
// (src[(row)*ne01 + col]) and written as ne01 rows of ne00 contiguous elements
// (dst[(col)*ne00 + row]). Consecutive matrices are n = ne00*ne01 elements apart
// in both buffers; the number of matrices is ne / n.
//
// Launch layout this kernel is written for:
//   blockDim = (CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1)
//   gridDim  = (ceil(ne01 / TILE_DIM_2D), ceil(ne00 / TILE_DIM_2D), ceil(nmat / CUDA_CPY_BLOCK_NM))
// Every block handles one TILE_DIM_2D x TILE_DIM_2D tile of up to CUDA_CPY_BLOCK_NM
// consecutive matrices. Static shared memory: one TILE_DIM_2D x TILE_DIM_2D tile of T.
//
// Parameters:
//   cx, cdst   - source / destination buffers, reinterpreted as arrays of T
//   ne         - total number of elements across all matrices
//   ne00, ne01 - matrix extents as described above
//   ne02, nb0x, ne1x, nb1x - not used by this kernel body (kept so the signature
//                            matches the other copy kernels' argument list)
template <typename T>
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                               const int nb12, const int nb13) {

    const T* src = reinterpret_cast<const T*>(cx);
    T* dst = reinterpret_cast<T*>(cdst);

    // Widen before multiplying so the per-matrix element count cannot overflow int.
    const int64_t n    = (int64_t) ne00 * ne01;
    const int64_t nmat = ne / n;

    const int x  = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
    const int y  = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x;  // transpose block offset
    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;

    __shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D];

    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

        // imat depends only on blockIdx.z and i, so this early exit is taken by the
        // whole block at once and the barriers below are never reached divergently.
        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
        if (imat >= nmat) {
            break;
        }

        // Load phase: the block has CUDA_CPY_BLOCK_ROWS rows of threads, so it walks the
        // full tile height (CUDA_CPY_TILE_DIM_2D, not the smaller CUDA_CPY_TILE_DIM) in
        // steps of CUDA_CPY_BLOCK_ROWS.
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (x < ne01 && y + j < ne00) {
                const int row = threadIdx.y + j;
                // XOR-swizzled column: element (row, threadIdx.x) is stored at column
                // threadIdx.x ^ row (presumably to avoid shared-memory bank conflicts on
                // the column-wise reads below without padding the tile).
                const int col = threadIdx.x ^ row;
                tile[row][col] = src[imat*n + (int64_t)(y + j)*ne01 + x];
            }
        }
        __syncthreads();

        // Store phase: read the tile transposed, undoing the swizzle
        // (stored column = logical column ^ row, with row == threadIdx.x here).
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (ty + j < ne01 && tx < ne00) {
                const int col = (threadIdx.y + j) ^ threadIdx.x;
                dst[imat*n + (int64_t)(ty + j)*ne00 + tx] = tile[threadIdx.x][col];
            }
        }

        // The tile is reused for the next matrix: make sure every thread has finished
        // reading it before any thread starts overwriting it.
        __syncthreads();
    }
}
8184

8285

8386
template <typename T, const int zero_at, const int one_at>
@@ -302,54 +305,63 @@ static void ggml_cpy_flt_cuda(
302305
// printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
303306
// printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13);
304307
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
305-
std::vector<std::tuple<int, int, int>> v;
306-
v.emplace_back(std::make_tuple(nb00, ne00, 0));
307-
v.emplace_back(std::make_tuple(nb01, ne01, 1));
308-
v.emplace_back(std::make_tuple(nb02, ne02, 2));
309-
std::sort(v.begin(), v.end(),
310-
[](auto &a, auto &b) {
311-
return std::get<0>(a) < std::get<0>(b);
312-
});
313-
const int ne0_new = std::get<1>(v[0]);
314-
const int ne1_new = std::get<1>(v[1]);
315-
const int ne2_new = std::get<1>(v[2]);
316-
int nidx[3];
317-
nidx[0] = std::get<2>(v[0]);
318-
nidx[1] = std::get<2>(v[1]);
319-
nidx[2] = std::get<2>(v[2]);
320-
// printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]);
321-
// printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new);
322-
const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
323-
const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
324-
325-
dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
326-
(ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
327-
(ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
328-
dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
329-
if(zero_at == 2){
330-
if(one_at == 1){
331-
cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
332-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
333-
nb10, nb11, nb12, nb13);
334-
}else{
335-
cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
336-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
337-
nb10, nb11, nb12, nb13);
338-
}
339-
} else if(zero_at == 1){
340-
if(one_at == 2){
341-
cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
342-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
343-
nb10, nb11, nb12, nb13);
344-
}else{
345-
cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
308+
if(ne02 == 1) {
309+
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
310+
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
311+
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
312+
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
313+
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
314+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
315+
} else{
316+
std::vector<std::tuple<int, int, int>> v;
317+
v.emplace_back(std::make_tuple(nb00, ne00, 0));
318+
v.emplace_back(std::make_tuple(nb01, ne01, 1));
319+
v.emplace_back(std::make_tuple(nb02, ne02, 2));
320+
std::sort(v.begin(), v.end(),
321+
[](auto &a, auto &b) {
322+
return std::get<0>(a) < std::get<0>(b);
323+
});
324+
const int ne0_new = std::get<1>(v[0]);
325+
const int ne1_new = std::get<1>(v[1]);
326+
const int ne2_new = std::get<1>(v[2]);
327+
int nidx[3];
328+
nidx[0] = std::get<2>(v[0]);
329+
nidx[1] = std::get<2>(v[1]);
330+
nidx[2] = std::get<2>(v[2]);
331+
// printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]);
332+
// printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new);
333+
const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
334+
const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
335+
336+
dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
337+
(ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
338+
(ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
339+
dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
340+
if(zero_at == 2){
341+
if(one_at == 1){
342+
cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
343+
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
344+
nb10, nb11, nb12, nb13);
345+
}else{
346+
cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
347+
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
348+
nb10, nb11, nb12, nb13);
349+
}
350+
} else if(zero_at == 1){
351+
if(one_at == 2){
352+
cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
353+
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
354+
nb10, nb11, nb12, nb13);
355+
}else{
356+
cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
357+
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
358+
nb10, nb11, nb12, nb13);
359+
}
360+
} else{
361+
cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
346362
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
347363
nb10, nb11, nb12, nb13);
348364
}
349-
} else{
350-
cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
351-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
352-
nb10, nb11, nb12, nb13);
353365
}
354366
} else{ // other
355367
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;

tests/test-backend-ops.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,7 +2512,7 @@ struct test_cpy : public test_case {
25122512
bool _src_transpose;
25132513

25142514
std::string vars() override {
2515-
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
2515+
return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);
25162516
}
25172517

25182518
double max_nmse_err() override {
@@ -7249,22 +7249,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
72497249
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
72507250
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
72517251

7252-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
7253-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
7254-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
7255-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
7256-
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
7252+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
7253+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
7254+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
7255+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
7256+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
72577257

72587258

7259-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7260-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7261-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7262-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7259+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7260+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7261+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7262+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
72637263

7264-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7265-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7266-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7267-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7264+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7265+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7266+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7267+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7268+
// test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
72687269

72697270

72707271
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

0 commit comments

Comments
 (0)