Skip to content

Commit 5afac4d

Browse files
committed
WIP
1 parent 851553e commit 5afac4d

File tree

1 file changed

+61
-5
lines changed

1 file changed

+61
-5
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,53 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
3535
cpy_1(cx + x_offset, cdst + dst_offset);
3636
}
3737

38+
template <typename T>
39+
static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne,
40+
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
41+
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
42+
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
43+
44+
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
45+
46+
const T* src = reinterpret_cast<const T*>(cx);
47+
T* dst = reinterpret_cast<T*>(cdst);
48+
49+
const int64_t nmat = ne / (ne00 * ne01);
50+
const int64_t n = ne00 * ne01;
51+
int width = ne01;
52+
int height = ne00;
53+
int x = blockIdx.x * TILE_DIM + threadIdx.x;
54+
int y = blockIdx.y * TILE_DIM + threadIdx.y;
55+
int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
56+
int ty = blockIdx.x * TILE_DIM + threadIdx.y;
57+
58+
__shared__ T tile[TILE_DIM][TILE_DIM];
59+
60+
for(int i = 0; i < BLOCK_NM; ++i){
61+
62+
const unsigned int imat = blockIdx.z * BLOCK_NM + i;
63+
if(imat >= nmat)
64+
break;
65+
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){
66+
if(x < width && y + j < height){
67+
const unsigned int idx = (y+j)*width + x;
68+
const int row = threadIdx.y+j;
69+
const int col = threadIdx.x ^ row;
70+
tile[row][col] = src[imat*n + idx];
71+
}
72+
}
73+
__syncthreads();
74+
75+
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){
76+
if(ty + j < width && tx < height){
77+
const unsigned int idx = (ty+j)*height + tx;
78+
const int col = (threadIdx.y+j) ^ threadIdx.x;
79+
dst[imat*n + idx] = tile[threadIdx.x][col];
80+
}
81+
}
82+
}
83+
}
84+
3885
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
3986
float * cdstf = (float *)(cdsti);
4087

@@ -136,15 +183,24 @@ cudaStream_t stream) {
136183
(cx, cdst, ne);
137184
}
138185

139-
template<typename src_t, typename dst_t>
186+
template<typename src_t, typename dst_t, bool transpose = false>
140187
static void ggml_cpy_flt_cuda(
141188
const char * cx, char * cdst, const int ne,
142189
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
143190
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
144-
145-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
146-
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
147-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
191+
if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
192+
std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
193+
&& transpose){
194+
dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM,
195+
(ne00 + TILE_DIM - 1) / TILE_DIM,
196+
(ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM );
197+
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
198+
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
199+
} else{ // other
200+
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
201+
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
202+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
203+
}
148204
}
149205

150206
static void ggml_cpy_f32_q8_0_cuda(

0 commit comments

Comments
 (0)