
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

+ const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
+ const int CUDA_CPY_BLOCK_NM = 8;     // block size of 3rd dimension if available
+ const int CUDA_CPY_BLOCK_ROWS = 8;   // block dimension for marching through rows
+
template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
    cpy_1(cx + x_offset, cdst + dst_offset);
}

+ template <typename T>
+ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+     const int nb12, const int nb13) {
+
+     const T * src = reinterpret_cast<const T *>(cx);
+     T * dst = reinterpret_cast<T *>(cdst);
+
+     const int64_t nmat = ne / (ne00 * ne01);
+     const int64_t n = ne00 * ne01;
+
+     const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+     const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+     const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
+     const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+
+     __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D + 1];
+
+ #pragma unroll
+     for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
+
+         const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+         if (imat >= nmat)
+             break;
+
+ #pragma unroll
+         for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+             if (x < ne01 && y + j < ne00) {
+                 const int row = threadIdx.y + j;
+                 const int col = threadIdx.x * sizeof(float)/sizeof(T);
+                 T * tile2 = reinterpret_cast<T *>(tile[row]);
+                 tile2[col] = src[imat*n + (y+j)*ne01 + x];
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+             if (ty + j < ne01 && tx < ne00) {
+                 const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
+                 const T * tile2 = reinterpret_cast<const T *>(tile[threadIdx.x]);
+                 dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
+             }
+         }
+     }
+ }
+
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

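Editor's note: the added kernel is the standard shared-memory tiled transpose. Each block stages a 32x32 tile so that both the global read (consecutive x within a warp) and the global write (consecutive tx) stay coalesced; the +1 padding column avoids shared-memory bank conflicts when a tile column is read back; and the z grid dimension batches up to CUDA_CPY_BLOCK_NM matrices per block. The tile is declared as float and reinterpreted as T so the same storage also serves half/nv_bfloat16 (hence the sizeof(float)/sizeof(T) scaling of the column index). A minimal standalone sketch of the same pattern for a single row-major float matrix (illustrative only; the kernel name, TILE/BLOCK_ROWS constants, and the launch shape are not from the patch):

const int TILE = 32;
const int BLOCK_ROWS = 8;

__global__ void transpose_f32_sketch(const float * src, float * dst, int rows, int cols) {
    __shared__ float tile[TILE][TILE + 1];       // +1 column breaks shared-memory bank conflicts

    const int x = blockIdx.x * TILE + threadIdx.x;  // column read from src (coalesced per warp)
    const int y = blockIdx.y * TILE + threadIdx.y;  // row read from src

    for (int j = 0; j < TILE; j += BLOCK_ROWS) {
        if (x < cols && y + j < rows) {
            tile[threadIdx.y + j][threadIdx.x] = src[(y + j) * cols + x];
        }
    }
    __syncthreads();

    const int tx = blockIdx.y * TILE + threadIdx.x; // column written to dst (coalesced per warp)
    const int ty = blockIdx.x * TILE + threadIdx.y; // row written to dst

    for (int j = 0; j < TILE; j += BLOCK_ROWS) {
        if (tx < rows && ty + j < cols) {
            dst[(ty + j) * rows + tx] = tile[threadIdx.x][threadIdx.y + j];
        }
    }
}
// Launch sketch: dim3 grid((cols + TILE - 1) / TILE, (rows + TILE - 1) / TILE);
//                dim3 block(TILE, BLOCK_ROWS);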
@@ -136,15 +189,38 @@ cudaStream_t stream) {
        (cx, cdst, ne);
}

- template <typename src_t, typename dst_t>
+ template <typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

-     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+     if (transposed) {
+         GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
+         int ne00n, ne01n, ne02n;
+         if (nb00 < nb02) {
+             ne00n = ne00;
+             ne01n = ne01;
+             ne02n = ne02;
+         } else if (nb00 > nb02) {
+             ne00n = ne00;
+             ne01n = ne01*ne02;
+             ne02n = 1;
+         } else {
+             GGML_ASSERT(false);
+         }
+
+         dim3 dimGrid((ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+         dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+         cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+             (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+     } else {
+         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+         cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+     }
}

static void ggml_cpy_f32_q8_0_cuda(
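Editor's note: the host wrapper now dispatches on the new `transposed` template parameter. When nb00 < nb02, each of the ne02 matrices is transposed separately; when nb00 > nb02, dims 1 and 2 are folded into a single ne01*ne02 axis and treated as one matrix. The grid's z dimension then covers the ne/(ne00n*ne01n) matrices in groups of CUDA_CPY_BLOCK_NM. A worked launch-size example under an assumed shape (not taken from the patch):

// Hypothetical shape: ne00n = 4096, ne01n = 512, a single matrix => ne = 4096*512 = 2097152
//   dimGrid  = ((512 + 31)/32, (4096 + 31)/32, (1 + 7)/8) = (16, 128, 1)
//   dimBlock = (32, 8, 1)  => 256 threads per block, 4 elements per thread per tile
//   16 * 128 * 1 = 2048 blocks, each handling one 32x32 tile (1024 elements),
//   so 2048 * 1024 = 2097152 elements are covered.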
@@ -310,6 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    char * src1_ddc = (char *) src1->data;

    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+     const bool can_be_transposed = nb01 == (int64_t) ggml_element_size(src0) && src0->ne[3] == 1;

    if (src0->type == src1->type && contiguous_srcs) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
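Editor's note: `can_be_transposed` keys off nb01 being exactly one element, i.e. src0 is a transposed view whose logically adjacent rows are stored one element apart; it is restricted to ne[3] == 1 because the wrapper asserts ne == ne00*ne01*ne02. A hedged sketch of how such a copy typically arises in a ggml graph (standard ggml API calls, not code from this patch; shapes are arbitrary):

#include "ggml.h"

// Sketch only: ctx must be an initialized ggml context.
static struct ggml_tensor * make_transposed_copy(struct ggml_context * ctx) {
    // 4096 x 512 source (ne0 = 4096, ne1 = 512), example shape
    struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 512);
    // ggml_transpose() only swaps ne[0]/ne[1] and nb[0]/nb[1]; the view is not contiguous.
    // ggml_cont() therefore has to materialize it, and on the CUDA backend that copy
    // reaches ggml_cuda_cpy() with nb01 == ggml_element_size(src0), taking the new path.
    struct ggml_tensor * xt = ggml_cont(ctx, ggml_transpose(ctx, x));
    return xt; // contiguous 512 x 4096 tensor
}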
@@ -322,7 +399,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-         ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         if (can_be_transposed) {
+             ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         } else {
+             ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +442,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-         ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         if (can_be_transposed) {
+             ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         } else {
+             ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         }
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +460,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         if (can_be_transposed) {
+             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         } else {
+             ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+         }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);