
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

-const int CUDA_CPY_TILE_DIM    = 16;
-const int CUDA_CPY_TILE_DIM_2D = 32;
-const int CUDA_CPY_BLOCK_NM    = 8;
-const int CUDA_CPY_BLOCK_ROWS  = 8;
+const int CUDA_CPY_TILE_DIM_2D = 32;  // 2D tile dimension for transposed blocks
+const int CUDA_CPY_BLOCK_NM    = 8;   // block size of 3rd dimension if available
+const int CUDA_CPY_BLOCK_ROWS  = 8;   // block dimension for marching through rows

 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
@@ -53,131 +52,41 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int
     const int64_t nmat = ne / (ne00 * ne01);
     const int64_t n    = ne00 * ne01;

-    int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
-    int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
-    int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
-    int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int x  = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+    const int y  = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
+    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;

     __shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D];

-    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){
+#pragma unroll
+    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

         const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
-        if (imat >= nmat)
+        if (imat >= nmat)
             break;
-        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
+
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
             if (x < ne01 && y + j < ne00){
                 const int row = threadIdx.y+j;
-                const int col = threadIdx.x ^ row;
+                const int col = threadIdx.x ^ row; // swizzling to avoid bank conflicts
                 tile[row][col] = src[imat*n + (y+j)*ne01 + x];
             }
         }
+
         __syncthreads();

-        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
-            if (ty + j < ne01 && tx < ne00){
-                const int col = (threadIdx.y+j) ^ threadIdx.x;
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if (ty + j < ne01 && tx < ne00) {
+                const int col = (threadIdx.y+j) ^ threadIdx.x; // swizzling to avoid bank conflicts
                 dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
             }
         }
     }
 }

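Note (not part of the diff): the XOR swizzle used above is the standard trick for a bank-conflict-free shared-memory transpose. Below is a minimal standalone sketch of the same pattern, assuming a square float matrix whose side is a multiple of 32 and the same 32x8 launch geometry as cpy_flt_transpose; the kernel name transpose32_swizzled and the TILE constant are illustrative only.

#include <cuda_runtime.h>

#define TILE 32  // tile side, mirrors CUDA_CPY_TILE_DIM_2D

// Launch with dim3 block(32, 8) and a grid of (n/TILE, n/TILE) blocks; n must be a multiple of 32.
__global__ void transpose32_swizzled(const float * __restrict__ src, float * __restrict__ dst, int n) {
    __shared__ float tile[TILE][TILE];

    const int x = blockIdx.x * TILE + threadIdx.x;
    const int y = blockIdx.y * TILE + threadIdx.y;

    // coalesced global load; store into shared memory with the column XOR-ed by the row
    for (int j = 0; j < TILE; j += blockDim.y) {
        const int row = threadIdx.y + j;
        tile[row][threadIdx.x ^ row] = src[(y + j) * n + x];
    }
    __syncthreads();

    const int tx = blockIdx.y * TILE + threadIdx.x;  // transposed block offset
    const int ty = blockIdx.x * TILE + threadIdx.y;

    // coalesced global store; the swizzled read hits 32 distinct banks per warp
    for (int j = 0; j < TILE; j += blockDim.y) {
        dst[(ty + j) * n + tx] = tile[threadIdx.x][(threadIdx.y + j) ^ threadIdx.x];
    }
}

Without the XOR, the read-out loop would access tile[threadIdx.x][constant], so all 32 threads of a warp would hit the same shared-memory bank; XOR-ing the column index with the row permutes each row's columns and spreads the 32 reads across 32 banks.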
-
-template <typename T, const int zero_at, const int one_at>
-static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
-
-    const T* src = reinterpret_cast<const T*>(cx);
-    T* dst = reinterpret_cast<T*>(cdst);
-
-    const int64_t n0 = ne00 * ne01;
-    const int64_t n1 = ne10 * ne11;
-
-    int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
-    int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
-    int z = blockIdx.z * CUDA_CPY_TILE_DIM;
-
-    __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];
-
-    for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
-        if (x < ne00 && y < ne01 && z + k < ne02){
-            // const int row = threadIdx.y+j;
-            // const int col = threadIdx.x ^ row;
-            const int row = threadIdx.y;
-            const int col = threadIdx.x;
-            tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x];
-        }
-    }
-    __syncthreads();
-
-    if (zero_at == 2){
-        int tx = blockIdx.z * CUDA_CPY_TILE_DIM;
-        if (one_at == 0){
-            int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
-            int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
-            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
-                // const int row = threadIdx.y;
-                // const int col = threadIdx.x;
-                // const int col = (threadIdx.y+j) ^ threadIdx.x;
-                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
-                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y];
-                }
-            }
-        } else { // one at 1
-            int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
-            int ty = blockIdx.y * CUDA_CPY_TILE_DIM;
-            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
-                // const int row = threadIdx.y;
-                // const int col = threadIdx.x;
-                // const int col = (threadIdx.y+j) ^ threadIdx.x;
-                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
-                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k];
-                }
-            }
-        }
-    } else if (zero_at == 1){
-        int tx = blockIdx.y * CUDA_CPY_TILE_DIM;
-        if (one_at == 0){
-            int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
-            int tz = blockIdx.z * CUDA_CPY_TILE_DIM;
-            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
-                // const int row = threadIdx.y;
-                // const int col = threadIdx.x;
-                // const int col = (threadIdx.y+j) ^ threadIdx.x;
-                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
-                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y];
-                }
-            }
-        } else { // one at 2
-            int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
-            int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
-            for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
-                // const int row = threadIdx.y;
-                // const int col = threadIdx.x;
-                // const int col = (threadIdx.y+j) ^ threadIdx.x;
-                if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
-                    dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k];
-                }
-            }
-        }
-    } else { // zero_at_0: means only possible is one_at_2 and two_at_1; otherwise, all contiguous
-        int tx = blockIdx.x * CUDA_CPY_TILE_DIM;
-        int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
-        int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
-        for (int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
-            // const int row = threadIdx.y;
-            // const int col = threadIdx.x;
-            // const int col = (threadIdx.y+j) ^ threadIdx.x;
-            if (tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
-                dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x];
-            }
-        }
-    }
-}
-
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);

@@ -279,72 +188,34 @@ cudaStream_t stream) {
         (cx, cdst, ne);
 }

-template <typename src_t, typename dst_t, bool coalesced = false>
+template <typename src_t, typename dst_t, bool transposed = false>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

-    if (coalesced){ // transpose
-        // GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
-        if ( nb00 < nb02 && nb02 <= nb03 ) {
-            dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                          (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                          (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
-            dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
-            cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
-                (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
-        } else {
-            std::vector<std::tuple<int, int, int>> v;
-            v.emplace_back(std::make_tuple(nb00, ne00, 0));
-            v.emplace_back(std::make_tuple(nb01, ne01, 1));
-            v.emplace_back(std::make_tuple(nb02, ne02, 2));
-            std::sort(v.begin(), v.end(),
-                [](auto &a, auto &b) {
-                    return std::get<0>(a) < std::get<0>(b);
-                });
-            const int ne0_new = std::get<1>(v[0]);
-            const int ne1_new = std::get<1>(v[1]);
-            const int ne2_new = std::get<1>(v[2]);
-            int nidx[3];
-            nidx[0] = std::get<2>(v[0]);
-            nidx[1] = std::get<2>(v[1]);
-            nidx[2] = std::get<2>(v[2]);
-            const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
-            const int one_at  = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);
-
-            dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
-                         (ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
-                         (ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
-            dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);
-
-            if (zero_at == 2){
-                if (one_at == 1){
-                    cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
-                        cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
-                        nb10, nb11, nb12, nb13);
-                }else{
-                    cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
-                        cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
-                        nb10, nb11, nb12, nb13);
-                }
-            } else if (zero_at == 1){
-                if (one_at == 2){
-                    cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
-                        cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
-                        nb10, nb11, nb12, nb13);
-                }else{
-                    cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
-                        cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
-                        nb10, nb11, nb12, nb13);
-                }
-            } else {
-                cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
-                    cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
-                    nb10, nb11, nb12, nb13);
-            }
+    if (transposed) {
+        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
+        int ne00n, ne01n, ne02n;
+        if (nb00 < nb02) {
+            ne00n = ne00;
+            ne01n = ne01;
+            ne02n = ne02;
+        } else if (nb00 > nb02) {
+            ne00n = ne00;
+            ne01n = ne01*ne02;
+            ne02n = 1;
+        } else {
+            GGML_ASSERT(false);
         }
-    } else { // other
+
+        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    } else {
         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
         cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
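For intuition (not part of the diff), here is a small host-side sketch of the grid sizing the transposed path above arrives at. The tensor shape ne00=64, ne01=100, ne02=4 is a made-up example; the constants are copied from the patch and the nb00 < nb02 branch is assumed.

#include <cstdio>

static const int CUDA_CPY_TILE_DIM_2D = 32;
static const int CUDA_CPY_BLOCK_NM    = 8;
static const int CUDA_CPY_BLOCK_ROWS  = 8;

int main() {
    // hypothetical shape: 64 x 100 elements per matrix, 4 matrices
    const int ne00 = 64, ne01 = 100, ne02 = 4;
    const int ne   = ne00 * ne01 * ne02;

    // nb00 < nb02 case: keep the 3D shape as-is
    const int ne00n = ne00, ne01n = ne01;

    const int gx = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;          // 4
    const int gy = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;          // 2
    const int gz = (ne / (ne01n * ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; // 1

    printf("grid = (%d, %d, %d), block = (%d, %d, 1)\n",
           gx, gy, gz, CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS);
    return 0;
}

With the 32x8 thread block, each block transposes one 32x32 tile in four row-strided passes (j += CUDA_CPY_BLOCK_ROWS), and each grid z-slice walks over up to CUDA_CPY_BLOCK_NM matrices.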
@@ -514,8 +385,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src1_ddc = (char *) src1->data;

     const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
-    const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) &&
-        (src0->ne[3] == 1 || (src0->nb[2] <= src0->nb[3] && src0->nb[0] < src0->nb[2]));
+    const bool can_be_transposed = nb01 == ggml_element_size(src0) && src0->ne[3] == 1;

     if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
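A brief aside on the simplified can_be_transposed gate (again, not part of the diff): ggml_transpose produces a view whose first two strides are swapped, so for a 2D transpose nb[1] collapses back to one element. A minimal sketch with made-up sizes; the variable names are illustrative only.

#include <cstddef>
#include <cstdio>

int main() {
    const size_t esize = sizeof(float);

    // contiguous F32 tensor of shape (4, 3): nb = {esize, 4*esize, 12*esize, 12*esize}
    size_t ne[4] = {4, 3, 1, 1};
    size_t nb[4] = {esize, 4 * esize, 12 * esize, 12 * esize};

    // ggml_transpose swaps ne[0]<->ne[1] and nb[0]<->nb[1] in the view
    size_t tne[4] = {ne[1], ne[0], ne[2], ne[3]};
    size_t tnb[4] = {nb[1], nb[0], nb[2], nb[3]};

    // nb[1] == element size and ne[3] == 1 -> can_be_transposed
    printf("transposed view: nb[1] = %zu, element size = %zu, ne[3] = %zu\n",
           tnb[1], esize, tne[3]);
    return 0;
}

The new check reads this layout directly from the strides instead of inspecting src0->op and contiguity.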
@@ -528,7 +398,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        if (can_be_transposed){
+        if (can_be_transposed) {
             ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -571,7 +441,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        if (can_be_transposed){
+        if (can_be_transposed) {
             ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -589,7 +459,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        if (can_be_transposed){
+        if (can_be_transposed) {
             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);