
Commit 35daa02

Author: bssrdf (authored and committed)
reformulated to handle more complicated transpose cases
1 parent d3bdcf8 commit 35daa02

File tree

ggml/src/ggml-cuda/cpy.cu
ggml/src/ggml-cuda/cpy.cuh

2 files changed: +200 -39 lines changed

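For context on the dispatch added in the cpy.cu diff below: ggml_cpy_flt_cuda now sorts the three source axes by byte stride and passes the sorted positions of source axes 0 and 1 to the kernel as the zero_at/one_at template arguments. The standalone sketch below reproduces only that selection step for a transposed 48x32 F32 matrix; the extents, strides, and main() harness are illustrative assumptions, not part of the commit.

// Standalone sketch of the (zero_at, one_at) selection used by the launcher below.
// Example strides describe a transposed 48x32 f32 matrix (assumed values).
#include <algorithm>
#include <cstdio>
#include <tuple>
#include <vector>

int main() {
    const int ne00 = 48, ne01 = 32, ne02 = 1;               // extents of the transposed view
    const int nb00 = 32 * 4, nb01 = 4, nb02 = 32 * 48 * 4;  // byte strides of the transposed view

    // (stride, extent, axis) tuples, sorted by stride so v[0] is the fastest-varying source axis.
    std::vector<std::tuple<int, int, int>> v = {
        {nb00, ne00, 0}, {nb01, ne01, 1}, {nb02, ne02, 2}
    };
    std::sort(v.begin(), v.end(),
              [](auto & a, auto & b) { return std::get<0>(a) < std::get<0>(b); });

    const int nidx[3] = { std::get<2>(v[0]), std::get<2>(v[1]), std::get<2>(v[2]) };
    const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
    const int one_at  = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);

    // For this layout axis 1 is innermost, axis 0 next, axis 2 outermost,
    // so the launcher would instantiate cpy_flt_coalesced<T, 1, 0>.
    std::printf("zero_at = %d, one_at = %d\n", zero_at, one_at);
}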

ggml/src/ggml-cuda/cpy.cu

Lines changed: 200 additions & 36 deletions
@@ -7,6 +7,8 @@
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
+const int CUDA_CPY_TILE_DIM = 16;
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,43 +37,153 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-template <typename T>
-static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+// template <typename T>
+// static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+//                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+//                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+//                                const int nb12, const int nb13) {
+
+//     const T* src = reinterpret_cast<const T*>(cx);
+//     T* dst = reinterpret_cast<T*>(cdst);
+
+//     const int64_t nmat = ne / (ne00 * ne01);
+//     const int64_t n = ne00 * ne01;
+
+//     int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
+//     int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
+//     int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset
+//     int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y;
+
+//     __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
+
+//     for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){
+
+//         const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+//         if(imat >= nmat)
+//             break;
+//         for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
+//             if(x < ne01 && y + j < ne00){
+//                 const int row = threadIdx.y+j;
+//                 const int col = threadIdx.x ^ row;
+//                 tile[row][col] = src[imat*n + (y+j)*ne01 + x];
+//             }
+//         }
+//         __syncthreads();
+
+//         for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
+//             if(ty + j < ne01 && tx < ne00){
+//                 const int col = (threadIdx.y+j) ^ threadIdx.x;
+//                 dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
+//             }
+//         }
+//     }
+// }
+
+
+template <typename T, const int zero_at, const int one_at>
+static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                const int nb12, const int nb13) {
 
     const T* src = reinterpret_cast<const T*>(cx);
     T* dst = reinterpret_cast<T*>(cdst);
-
-    const int64_t nmat = ne / (ne00 * ne01);
-    const int64_t n = ne00 * ne01;
+    // nidx[0] inner most
+    // nidx[1] middle
+    // nidx[2] outer most
+    // const int64_t nmat = ne / (ne00 * ne01);
+    // const int64_t n = ne00 * ne01;
+    // const int64_t ne00 = ne0[nidx[0]];
+    // const int64_t ne01 = ne0[nidx[1]];
+    // const int64_t ne02 = ne0[nidx[2]];
+    const int64_t n0 = ne00 * ne01;
+    // const int64_t ne10 = ne1[0];
+    // const int64_t ne11 = ne1[1];
+    // const int64_t ne12 = ne1[2];
+    const int64_t n1 = ne10 * ne11;
 
     int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
     int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
-    int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset
-    int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y;
-
-    __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
-
-    for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){
-
-        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
-        if(imat >= nmat)
-            break;
-        for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
-            if(x < ne01 && y + j < ne00){
-                const int row = threadIdx.y+j;
-                const int col = threadIdx.x ^ row;
-                tile[row][col] = src[imat*n + (y+j)*ne01 + x];
+    int z = blockIdx.z * CUDA_CPY_TILE_DIM;
+    // int tx = blockIdx.x * CUDA_CPY_TILE_DIM[ntidx[0]] + threadIdx.x; // transpose block offset
+    // int ty = blockIdx.y * CUDA_CPY_TILE_DIM[ntidx[1]] + threadIdx.y;
+    // int tz = blockIdx.z * CUDA_CPY_TILE_DIM[ntidx[2]];
+
+    __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
+
+    for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
+        // for (int j = 0; j < CUDA_CPY_TILE_DIM[1]; ++j){
+        if(x < ne00 && y < ne01 && z + k < ne02){
+            // const int row = threadIdx.y+j;
+            // const int col = threadIdx.x ^ row;
+            const int row = threadIdx.y;
+            const int col = threadIdx.x;
+            tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x];
+        }
+        // }
+    }
+    __syncthreads();
+
+    if(zero_at == 2){
+        int tx = blockIdx.z * CUDA_CPY_TILE_DIM;
+        if(one_at == 0){
+            int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
+            int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
+            for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y];
+                }
+            }
+        } else{ // one at 1
+            int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
+            int ty = blockIdx.y * CUDA_CPY_TILE_DIM;
+            for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k];
+                }
             }
         }
-        __syncthreads();
-
-        for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS){
-            if(ty + j < ne01 && tx < ne00){
-                const int col = (threadIdx.y+j) ^ threadIdx.x;
-                dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
+    } else if(zero_at == 1){
+        int tx = blockIdx.y * CUDA_CPY_TILE_DIM;
+        if(one_at == 0){
+            int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
+            int tz = blockIdx.z * CUDA_CPY_TILE_DIM;
+            for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y];
+                }
+            }
+        } else { // one at 2
+            int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
+            int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
+            for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k];
+                }
+            }
+        }
+    } else{ // zero_at_0: means only possible is one_at_2 and two_at_1; otherwise, all contiguous
+        int tx = blockIdx.x * CUDA_CPY_TILE_DIM;
+        int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
+        int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
+        for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
+            // const int row = threadIdx.y;
+            // const int col = threadIdx.x;
+            // const int col = (threadIdx.y+j) ^ threadIdx.x;
+            if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
+                dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x];
             }
         }
     }
 }
@@ -178,18 +290,67 @@ cudaStream_t stream) {
         (cx, cdst, ne);
 }
 
-template<typename src_t, typename dst_t, bool transpose = false>
+template<typename src_t, typename dst_t, bool coalesced = false>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
-    if (transpose){ //transpose
-        dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
-                      (ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
-                      (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM );
-        dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_BLOCK_ROWS, 1);
-        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    if (coalesced){ //transpose
+        // printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
+        // printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12);
+        // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
+        // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13);
+        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
+        std::vector<std::tuple<int, int, int>> v;
+        v.emplace_back(std::make_tuple(nb00, ne00, 0));
+        v.emplace_back(std::make_tuple(nb01, ne01, 1));
+        v.emplace_back(std::make_tuple(nb02, ne02, 2));
+        std::sort(v.begin(), v.end(),
+                  [](auto &a, auto &b) {
+                      return std::get<0>(a) < std::get<0>(b);
+                  });
+        const int ne0_new = std::get<1>(v[0]);
+        const int ne1_new = std::get<1>(v[1]);
+        const int ne2_new = std::get<1>(v[2]);
+        int nidx[3];
+        nidx[0] = std::get<2>(v[0]);
+        nidx[1] = std::get<2>(v[1]);
+        nidx[2] = std::get<2>(v[2]);
+        // printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]);
+        // printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new);
+        const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
+        const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
+
+        dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
+                      (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
+                      (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
+        dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
+        if(zero_at == 2){
+            if(one_at == 1){
+                cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            }else{
+                cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            }
+        } else if(zero_at == 1){
+            if(one_at == 2){
+                cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            }else{
+                cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            }
+        } else{
+            cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
+                cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                nb10, nb11, nb12, nb13);
+        }
     } else{ // other
         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
         cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
@@ -372,7 +533,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        if(src0->op == GGML_OP_TRANSPOSE){
+        if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){
+            // printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
             ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -415,7 +577,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        if(src0->op == GGML_OP_TRANSPOSE){
+        if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){
+            // printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
             ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -433,7 +596,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        if(src0->op == GGML_OP_TRANSPOSE){
+        if(src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1){
+            // printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
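
As a sanity check on the indexing above, the following CPU-only sketch emulates the zero_at == 1, one_at == 0 branch of cpy_flt_coalesced (the layout a plain 2D transpose produces), with serial loops standing in for the thread block, and verifies that the output is the transpose of the input. The extents and the checking harness are illustrative assumptions, not part of the commit.

// Host-side emulation of the zero_at == 1 / one_at == 0 write branch (assumed test harness).
#include <cstdio>
#include <vector>

int main() {
    const int TILE = 16;                      // mirrors CUDA_CPY_TILE_DIM
    const int R = 37, C = 53;                 // source matrix: R rows of C floats
    std::vector<float> src(R * C), dst(R * C, -1.0f);
    for (int i = 0; i < R * C; ++i) src[i] = (float) i;

    // Extents as the kernel sees them after the stride sort: source reads use
    // ne00 = C (innermost), ne01 = R; the contiguous destination has ne10 = R,
    // ne11 = C. ne02 = ne12 = 1, so only the k = 0 slice of the tile is used.
    const int ne00 = C, ne01 = R, ne10 = R, ne11 = C;

    std::vector<float> tile(TILE * TILE);     // stands in for tile[0][ty][tx]
    for (int by = 0; by * TILE < ne01; ++by)
    for (int bx = 0; bx * TILE < ne00; ++bx) {
        for (int ty = 0; ty < TILE; ++ty)     // read phase: consecutive tx -> consecutive src
        for (int tx = 0; tx < TILE; ++tx) {
            const int x = bx * TILE + tx, y = by * TILE + ty;
            if (x < ne00 && y < ne01) tile[ty * TILE + tx] = src[y * ne00 + x];
        }
        for (int ty = 0; ty < TILE; ++ty)     // write phase: consecutive tx -> consecutive dst
        for (int tx = 0; tx < TILE; ++tx) {
            const int col = by * TILE + tx;   // tx = blockIdx.y * TILE in the kernel
            const int row = bx * TILE + ty;   // ty = blockIdx.x * TILE in the kernel
            if (row < ne11 && col < ne10) dst[row * ne10 + col] = tile[tx * TILE + ty];
        }
    }

    int mismatches = 0;                       // dst should be the transpose: C rows of R floats
    for (int r = 0; r < C; ++r)
        for (int c = 0; c < R; ++c)
            if (dst[r * R + c] != src[c * C + r]) ++mismatches;
    std::printf("mismatches: %d\n", mismatches);  // expect 0
}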

ggml/src/ggml-cuda/cpy.cuh

Lines changed: 0 additions & 3 deletions
@@ -1,9 +1,6 @@
 #include "common.cuh"
 
 #define CUDA_CPY_BLOCK_SIZE 64
-#define CUDA_CPY_TILE_DIM 32
-#define CUDA_CPY_BLOCK_ROWS 8
-#define CUDA_CPY_BLOCK_NM 8
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
