Skip to content

Commit 28e5cf6

Browse files
author
bssrdf
committed
keep only the kernel for true transposed case; updated with review suggestions
1 parent 351cf56 commit 28e5cf6

File tree

1 file changed

+45
-175
lines changed

1 file changed

+45
-175
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 45 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88

99
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
1010

11-
const int CUDA_CPY_TILE_DIM = 16;
12-
const int CUDA_CPY_TILE_DIM_2D = 32;
13-
const int CUDA_CPY_BLOCK_NM = 8;
14-
const int CUDA_CPY_BLOCK_ROWS = 8;
11+
const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
12+
const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
13+
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
1514

1615
template <cpy_kernel_t cpy_1>
1716
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
@@ -53,131 +52,41 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
5352
const int64_t nmat = ne / (ne00 * ne01);
5453
const int64_t n = ne00 * ne01;
5554

56-
int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
57-
int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
58-
int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
59-
int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
55+
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
56+
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
57+
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
58+
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
6059

6160
__shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D];
6261

63-
for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){
62+
#pragma unroll
63+
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
6464

6565
const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
66-
if(imat >= nmat)
66+
if (imat >= nmat)
6767
break;
68-
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
68+
69+
#pragma unroll
70+
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
6971
if(x < ne01 && y + j < ne00){
7072
const int row = threadIdx.y+j;
71-
const int col = threadIdx.x ^ row;
73+
const int col = threadIdx.x ^ row; //swizzling to avoid bank conflicts
7274
tile[row][col] = src[imat*n + (y+j)*ne01 + x];
7375
}
7476
}
77+
7578
__syncthreads();
7679

77-
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
78-
if(ty + j < ne01 && tx < ne00){
79-
const int col = (threadIdx.y+j) ^ threadIdx.x;
80+
#pragma unroll
81+
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
82+
if (ty + j < ne01 && tx < ne00) {
83+
const int col = (threadIdx.y+j) ^ threadIdx.x; //swizzling to avoid bank conflicts
8084
dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
8185
}
8286
}
8387
}
8488
}
8589

86-
87-
template <typename T, const int zero_at, const int one_at>
88-
static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne,
89-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
90-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
91-
const int nb12, const int nb13) {
92-
93-
const T* src = reinterpret_cast<const T*>(cx);
94-
T* dst = reinterpret_cast<T*>(cdst);
95-
96-
const int64_t n0 = ne00 * ne01;
97-
const int64_t n1 = ne10 * ne11;
98-
99-
int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
100-
int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
101-
int z = blockIdx.z * CUDA_CPY_TILE_DIM;
102-
103-
__shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
104-
105-
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
106-
if(x < ne00 && y < ne01 && z + k < ne02){
107-
// const int row = threadIdx.y+j;
108-
// const int col = threadIdx.x ^ row;
109-
const int row = threadIdx.y;
110-
const int col = threadIdx.x;
111-
tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x];
112-
}
113-
}
114-
__syncthreads();
115-
116-
if(zero_at == 2){
117-
int tx = blockIdx.z * CUDA_CPY_TILE_DIM;
118-
if(one_at == 0){
119-
int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
120-
int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
121-
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
122-
// const int row = threadIdx.y;
123-
// const int col = threadIdx.x;
124-
// const int col = (threadIdx.y+j) ^ threadIdx.x;
125-
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
126-
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y];
127-
}
128-
}
129-
} else{ // one at 1
130-
int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
131-
int ty = blockIdx.y * CUDA_CPY_TILE_DIM;
132-
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
133-
// const int row = threadIdx.y;
134-
// const int col = threadIdx.x;
135-
// const int col = (threadIdx.y+j) ^ threadIdx.x;
136-
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
137-
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k];
138-
}
139-
}
140-
}
141-
} else if(zero_at == 1){
142-
int tx = blockIdx.y * CUDA_CPY_TILE_DIM;
143-
if(one_at == 0){
144-
int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
145-
int tz = blockIdx.z * CUDA_CPY_TILE_DIM;
146-
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
147-
// const int row = threadIdx.y;
148-
// const int col = threadIdx.x;
149-
// const int col = (threadIdx.y+j) ^ threadIdx.x;
150-
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
151-
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y];
152-
}
153-
}
154-
} else { // one at 2
155-
int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
156-
int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
157-
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
158-
// const int row = threadIdx.y;
159-
// const int col = threadIdx.x;
160-
// const int col = (threadIdx.y+j) ^ threadIdx.x;
161-
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
162-
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k];
163-
}
164-
}
165-
}
166-
} else{ // zero_at_0: means only possible is one_at_2 and two_at_1; otherwise, all contiguous
167-
int tx = blockIdx.x * CUDA_CPY_TILE_DIM;
168-
int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
169-
int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
170-
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
171-
// const int row = threadIdx.y;
172-
// const int col = threadIdx.x;
173-
// const int col = (threadIdx.y+j) ^ threadIdx.x;
174-
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
175-
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x];
176-
}
177-
}
178-
}
179-
}
180-
18190
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
18291
float * cdstf = (float *)(cdsti);
18392

@@ -279,72 +188,34 @@ cudaStream_t stream) {
279188
(cx, cdst, ne);
280189
}
281190

282-
template<typename src_t, typename dst_t, bool coalesced = false>
191+
template<typename src_t, typename dst_t, bool transposed = false>
283192
static void ggml_cpy_flt_cuda(
284193
const char * cx, char * cdst, const int ne,
285194
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
286195
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
287196

288-
if (coalesced){ //transpose
289-
// GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
290-
if( nb00 < nb02 && nb02 <= nb03 ) {
291-
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
292-
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
293-
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
294-
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
295-
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
296-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
297-
} else{
298-
std::vector<std::tuple<int, int, int>> v;
299-
v.emplace_back(std::make_tuple(nb00, ne00, 0));
300-
v.emplace_back(std::make_tuple(nb01, ne01, 1));
301-
v.emplace_back(std::make_tuple(nb02, ne02, 2));
302-
std::sort(v.begin(), v.end(),
303-
[](auto &a, auto &b) {
304-
return std::get<0>(a) < std::get<0>(b);
305-
});
306-
const int ne0_new = std::get<1>(v[0]);
307-
const int ne1_new = std::get<1>(v[1]);
308-
const int ne2_new = std::get<1>(v[2]);
309-
int nidx[3];
310-
nidx[0] = std::get<2>(v[0]);
311-
nidx[1] = std::get<2>(v[1]);
312-
nidx[2] = std::get<2>(v[2]);
313-
const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
314-
const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
315-
316-
dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
317-
(ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
318-
(ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
319-
dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
320-
321-
if(zero_at == 2){
322-
if(one_at == 1){
323-
cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
324-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
325-
nb10, nb11, nb12, nb13);
326-
}else{
327-
cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
328-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
329-
nb10, nb11, nb12, nb13);
330-
}
331-
} else if(zero_at == 1){
332-
if(one_at == 2){
333-
cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
334-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
335-
nb10, nb11, nb12, nb13);
336-
}else{
337-
cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
338-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
339-
nb10, nb11, nb12, nb13);
340-
}
341-
} else{
342-
cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
343-
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
344-
nb10, nb11, nb12, nb13);
345-
}
197+
if (transposed) {
198+
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
199+
int ne00n, ne01n, ne02n;
200+
if (nb00 < nb02) {
201+
ne00n = ne00;
202+
ne01n = ne01;
203+
ne02n = ne02;
204+
} else if (nb00 > nb02) {
205+
ne00n = ne00;
206+
ne01n = ne01*ne02;
207+
ne02n = 1;
208+
} else {
209+
GGML_ASSERT(false);
346210
}
347-
} else{ // other
211+
212+
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
213+
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
214+
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
215+
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
216+
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
217+
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
218+
} else {
348219
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
349220
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
350221
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
@@ -514,8 +385,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
514385
char * src1_ddc = (char *) src1->data;
515386

516387
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
517-
const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) &&
518-
(src0->ne[3] == 1 || (src0->nb[2] <= src0->nb[3] && src0->nb[0] < src0->nb[2]));
388+
const bool can_be_transposed = nb01 == ggml_element_size(src0) && src0->ne[3] == 1;
519389

520390
if (src0->type == src1->type && contiguous_srcs) {
521391
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -528,7 +398,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
528398
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
529399
}
530400
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
531-
if(can_be_transposed){
401+
if (can_be_transposed) {
532402
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
533403
} else {
534404
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -571,7 +441,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
571441
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
572442
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
573443
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
574-
if(can_be_transposed){
444+
if (can_be_transposed) {
575445
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
576446
} else {
577447
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -589,7 +459,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
589459
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
590460
}
591461
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
592-
if(can_be_transposed){
462+
if (can_be_transposed) {
593463
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
594464
} else {
595465
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);

0 commit comments

Comments
 (0)