
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

+const int CUDA_CPY_TILE_DIM = 16;
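+// Edge length of the shared-memory tile used by the coalesced copy kernel below:
+// each thread block stages a CUDA_CPY_TILE_DIM^3 cube of elements.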
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,43 +37,153 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }

-template <typename T>
-static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+// template <typename T>
+// static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+//     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+//     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+//     const int nb12, const int nb13) {
+
+//     const T* src = reinterpret_cast<const T*>(cx);
+//     T* dst = reinterpret_cast<T*>(cdst);
+
+//     const int64_t nmat = ne / (ne00 * ne01);
+//     const int64_t n = ne00 * ne01;
+
+//     int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
+//     int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
+//     int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset
+//     int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y;
+
+//     __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
+
+//     for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
+
+//         const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+//         if (imat >= nmat)
+//             break;
+//         for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS) {
+//             if (x < ne01 && y + j < ne00) {
+//                 const int row = threadIdx.y + j;
+//                 const int col = threadIdx.x ^ row;
+//                 tile[row][col] = src[imat*n + (y+j)*ne01 + x];
+//             }
+//         }
+//         __syncthreads();
+
+//         for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS) {
+//             if (ty + j < ne01 && tx < ne00) {
+//                 const int col = (threadIdx.y+j) ^ threadIdx.x;
+//                 dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
+//             }
+//         }
+//     }
+// }
+
+
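+// Copy a permuted (e.g. transposed) 3-D tensor so that both the global loads and the
+// global stores are coalesced. The caller passes ne00/ne01/ne02 already reordered by
+// ascending source stride; zero_at and one_at give the positions, in that stride-sorted
+// order, of the destination's innermost (ne10) and middle (ne11) dimensions. Each block
+// loads a CUDA_CPY_TILE_DIM^3 cube of src in source memory order into shared memory,
+// then writes it out in destination memory order, with the permutation absorbed by the
+// shared-memory indexing.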
+template <typename T, const int zero_at, const int one_at>
+static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
     const int nb12, const int nb13) {

     const T* src = reinterpret_cast<const T*>(cx);
     T* dst = reinterpret_cast<T*>(cdst);
-
-    const int64_t nmat = ne / (ne00 * ne01);
-    const int64_t n = ne00 * ne01;
+    // nidx[0] inner most
+    // nidx[1] middle
+    // nidx[2] outer most
+    // const int64_t nmat = ne / (ne00 * ne01);
+    // const int64_t n = ne00 * ne01;
+    // const int64_t ne00 = ne0[nidx[0]];
+    // const int64_t ne01 = ne0[nidx[1]];
+    // const int64_t ne02 = ne0[nidx[2]];
+    const int64_t n0 = ne00 * ne01;
+    // const int64_t ne10 = ne1[0];
+    // const int64_t ne11 = ne1[1];
+    // const int64_t ne12 = ne1[2];
+    const int64_t n1 = ne10 * ne11;

     int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
     int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
-    int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transpose block offset
-    int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y;
-
-    __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
-
-    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
-
-        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
-        if (imat >= nmat)
-            break;
-        for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS) {
-            if (x < ne01 && y + j < ne00) {
-                const int row = threadIdx.y + j;
-                const int col = threadIdx.x ^ row;
-                tile[row][col] = src[imat*n + (y+j)*ne01 + x];
+    int z = blockIdx.z * CUDA_CPY_TILE_DIM;
+    // int tx = blockIdx.x * CUDA_CPY_TILE_DIM[ntidx[0]] + threadIdx.x; // transpose block offset
+    // int ty = blockIdx.y * CUDA_CPY_TILE_DIM[ntidx[1]] + threadIdx.y;
+    // int tz = blockIdx.z * CUDA_CPY_TILE_DIM[ntidx[2]];
+
+    __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
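+    // With CUDA_CPY_TILE_DIM = 16 this is 4096 elements per block:
+    // 16 KiB of shared memory for float, 8 KiB for half/bf16.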
+
+    for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k) {
+        // for (int j = 0; j < CUDA_CPY_TILE_DIM[1]; ++j){
+        if (x < ne00 && y < ne01 && z + k < ne02) {
+            // const int row = threadIdx.y+j;
+            // const int col = threadIdx.x ^ row;
+            const int row = threadIdx.y;
+            const int col = threadIdx.x;
+            tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x];
+        }
+        // }
+    }
+    __syncthreads();
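+    // Write the tile back in destination memory order. zero_at/one_at are compile-time
+    // template constants, so only the matching branch below is kept; in every branch
+    // threadIdx.x walks the destination's innermost dimension, keeping the stores
+    // coalesced, while the permutation is handled by the shared-memory reads.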
+
+    if (zero_at == 2) {
+        int tx = blockIdx.z * CUDA_CPY_TILE_DIM;
+        if (one_at == 0) {
+            int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
+            int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
+            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k) {
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10) {
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y];
+                }
+            }
+        } else { // one_at == 1
+            int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
+            int ty = blockIdx.y * CUDA_CPY_TILE_DIM;
+            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k) {
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10) {
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k];
+                }
            }
        }
-        __syncthreads();
-
-        for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS) {
-            if (ty + j < ne01 && tx < ne00) {
-                const int col = (threadIdx.y+j) ^ threadIdx.x;
-                dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
+    } else if (zero_at == 1) {
+        int tx = blockIdx.y * CUDA_CPY_TILE_DIM;
+        if (one_at == 0) {
+            int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
+            int tz = blockIdx.z * CUDA_CPY_TILE_DIM;
+            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k) {
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10) {
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y];
+                }
+            }
+        } else { // one_at == 2
+            int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
+            int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
+            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k) {
+                // const int row = threadIdx.y;
+                // const int col = threadIdx.x;
+                // const int col = (threadIdx.y+j) ^ threadIdx.x;
+                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10) {
+                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k];
+                }
+            }
+        }
+    } else { // zero_at == 0: the only possible permutation here is one_at == 2 (dim 2 at position 1); otherwise the copy is fully contiguous
+        int tx = blockIdx.x * CUDA_CPY_TILE_DIM;
+        int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
+        int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
+        for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k) {
+            // const int row = threadIdx.y;
+            // const int col = threadIdx.x;
+            // const int col = (threadIdx.y+j) ^ threadIdx.x;
+            if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10) {
+                dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x];
            }
        }
    }
@@ -178,18 +290,67 @@ cudaStream_t stream) {
         (cx, cdst, ne);
 }

-template <typename src_t, typename dst_t, bool transpose = false>
+template <typename src_t, typename dst_t, bool coalesced = false>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

-    if (transpose) { // transpose
-        dim3 dimGrid((ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
-                     (ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
-                     (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
-        dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_BLOCK_ROWS, 1);
-        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    if (coalesced) { // transpose
+        // printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
+        // printf("b %zu, %zu, %zu, %zu, \n", ne, ne10, ne11, ne12);
+        // printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
+        // printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13);
+        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
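+        // Order the three source dims by ascending byte stride: after the sort, v[0] is
+        // the fastest-varying dimension in memory. nidx[] records each original dim's
+        // position in this order, from which zero_at/one_at are derived below.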
+        std::vector<std::tuple<int, int, int>> v;
+        v.emplace_back(std::make_tuple(nb00, ne00, 0));
+        v.emplace_back(std::make_tuple(nb01, ne01, 1));
+        v.emplace_back(std::make_tuple(nb02, ne02, 2));
+        std::sort(v.begin(), v.end(),
+                  [](auto &a, auto &b) {
+                      return std::get<0>(a) < std::get<0>(b);
+                  });
+        const int ne0_new = std::get<1>(v[0]);
+        const int ne1_new = std::get<1>(v[1]);
+        const int ne2_new = std::get<1>(v[2]);
+        int nidx[3];
+        nidx[0] = std::get<2>(v[0]);
+        nidx[1] = std::get<2>(v[1]);
+        nidx[2] = std::get<2>(v[2]);
+        // printf(" nidx: [%d, %d, %d] \n", nidx[0], nidx[1], nidx[2]);
+        // printf(" ne_new: [%d, %d, %d] \n", ne0_new, ne1_new, ne2_new);
+        const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
+        const int one_at  = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
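+        // zero_at / one_at: sorted positions of the original dims 0 and 1 (the
+        // destination's two innermost dims). For a plain 2-D transpose of a contiguous
+        // matrix, nb01 < nb00 < nb02, so nidx = {1, 0, 2} and the launch below picks
+        // cpy_flt_coalesced<dst_t, 1, 0>.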
+
+        dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
+                     (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
+                     (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
+        dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
+        if (zero_at == 2) {
+            if (one_at == 1) {
+                cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            } else {
+                cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            }
+        } else if (zero_at == 1) {
+            if (one_at == 2) {
+                cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            } else {
+                cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
+                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                    nb10, nb11, nb12, nb13);
+            }
+        } else {
+            cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
+                cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
+                nb10, nb11, nb12, nb13);
+        }
     } else { // other
         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
         cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
@@ -372,7 +533,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        if (src0->op == GGML_OP_TRANSPOSE) {
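+        // Route non-contiguous transposed views (with ne[3] == 1, as the coalesced
+        // kernel assumes) through the coalesced copy; everything else falls back to
+        // the generic element-wise kernel.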
+        if (src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1) {
+            // printf("A %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
             ggml_cpy_flt_cuda<float, float, true>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -415,7 +577,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        if (src0->op == GGML_OP_TRANSPOSE) {
+        if (src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1) {
+            // printf("B %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
             ggml_cpy_flt_cuda<half, half, true>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<half, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -433,7 +596,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             ggml_cpy_flt_cuda<half, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        if (src0->op == GGML_OP_TRANSPOSE) {
+        if (src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) && src0->ne[3] == 1) {
+            // printf("C %s, %s \n", ggml_op_desc(src0), ggml_op_desc(src1));
             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);