Skip to content

Commit 351cf56

Browse files
committed
Merge branch 'cuda-transpose-cpy' of github.com:bssrdf/llama.cpp into cuda-transpose-cpy
2 parents 8dbb4c7 + d49232c commit 351cf56

File tree

1 file changed

+5 -32 lines changed

1 file changed

+5 -32 lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 5 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -92,39 +92,24 @@ static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int
9292

9393
const T* src = reinterpret_cast<const T*>(cx);
9494
T* dst = reinterpret_cast<T*>(cdst);
95-
// nidx[0] inner most
96-
// nidx[1] middle
97-
// nidx[2] outer most
98-
// const int64_t nmat = ne / (ne00 * ne01);
99-
// const int64_t n = ne00 * ne01;
100-
// const int64_t ne00 = ne0[nidx[0]];
101-
// const int64_t ne01 = ne0[nidx[1]];
102-
// const int64_t ne02 = ne0[nidx[2]];
95+
10396
const int64_t n0 = ne00 * ne01;
104-
// const int64_t ne10 = ne1[0];
105-
// const int64_t ne11 = ne1[1];
106-
// const int64_t ne12 = ne1[2];
10797
const int64_t n1 = ne10 * ne11;
10898

10999
int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
110100
int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
111101
int z = blockIdx.z * CUDA_CPY_TILE_DIM;
112-
// int tx = blockIdx.x * CUDA_CPY_TILE_DIM[ntidx[0]] + threadIdx.x; // transpose block offset
113-
// int ty = blockIdx.y * CUDA_CPY_TILE_DIM[ntidx[1]] + threadIdx.y;
114-
// int tz = blockIdx.z * CUDA_CPY_TILE_DIM[ntidx[2]];
115102

116103
__shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
117104

118105
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
119-
// for (int j = 0; j < CUDA_CPY_TILE_DIM[1]; ++j){
120106
if(x < ne00 && y < ne01 && z + k < ne02){
121107
// const int row = threadIdx.y+j;
122108
// const int col = threadIdx.x ^ row;
123109
const int row = threadIdx.y;
124110
const int col = threadIdx.x;
125111
tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x];
126112
}
127-
// }
128113
}
129114
__syncthreads();
130115

@@ -301,23 +286,15 @@ static void ggml_cpy_flt_cuda(
301286
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
302287

303288
if (coalesced){ //transpose
304-
// printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
305-
// printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12);
306-
// printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
307-
// printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13);
308289
// GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
309290
if( nb00 < nb02 && nb02 <= nb03 ) {
310-
// printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
311-
// printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
312291
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
313292
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
314293
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
315294
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
316295
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
317296
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
318297
} else{
319-
// printf("b %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
320-
// printf("d %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
321298
std::vector<std::tuple<int, int, int>> v;
322299
v.emplace_back(std::make_tuple(nb00, ne00, 0));
323300
v.emplace_back(std::make_tuple(nb01, ne01, 1));
@@ -333,15 +310,14 @@ static void ggml_cpy_flt_cuda(
333310
nidx[0] = std::get<2>(v[0]);
334311
nidx[1] = std::get<2>(v[1]);
335312
nidx[2] = std::get<2>(v[2]);
336-
// printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]);
337-
// printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new);
338313
const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
339314
const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
340315

341-
dim3 dimGrid( (ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
342-
(ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
343-
(ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
316+
dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
317+
(ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
318+
(ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
344319
dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
320+
345321
if(zero_at == 2){
346322
if(one_at == 1){
347323
cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
@@ -553,7 +529,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
553529
}
554530
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
555531
if(can_be_transposed){
556-
// printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
557532
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
558533
} else {
559534
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -597,7 +572,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
597572
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
598573
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
599574
if(can_be_transposed){
600-
// printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
601575
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
602576
} else {
603577
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -616,7 +590,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
616590
}
617591
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
618592
if(can_be_transposed){
619-
// printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
620593
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
621594
} else {
622595
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);

0 commit comments

Comments
 (0)