Skip to content

Commit 6d12288

Browse files
author
bssrdf
committed
WIP: fixed a bug in cpy transpose index computation
1 parent a3784e1 commit 6d12288

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co
5252
const int64_t nmat = ne /(ne00 * ne01);
5353
const int64_t n = ne00 * ne01;
5454
// const int64_t n = ne01 * ne02;
55-
int width = gridDim.x * TILE_DIM;
55+
int width = ne01;
5656
int x = blockIdx.x * TILE_DIM + threadIdx.x;
5757
int y = blockIdx.y * TILE_DIM + threadIdx.y;
5858
int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
@@ -194,21 +194,22 @@ static void ggml_cpy_flt_cuda(
194194
const char * cx, char * cdst, const int ne,
195195
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
196196
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
197-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
198-
if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
197+
198+
if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
199199
std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
200200
&& transpose){
201201
// printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
202202
// printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11);
203203
// if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
204204
// if (transpose) { //transpose
205205
// printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
206-
dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM,
207-
(ne01 + TILE_DIM - 1) / TILE_DIM,
206+
dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM,
207+
(ne00 + TILE_DIM - 1) / TILE_DIM,
208208
(ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM );
209209
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
210210
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
211211
} else{ // other
212+
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
212213
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
213214
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
214215
}
@@ -401,7 +402,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
401402
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
402403
}
403404
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
404-
if(src1->op_params[10] == 999){
405+
if(src0->op_params[10] == 999){
405406
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
406407
} else {
407408
ggml_cpy_flt_cuda<float, float, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -436,7 +437,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
436437
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
437438
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
438439
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
439-
if(src1->op_params[10] == 999){
440+
if(src0->op_params[10] == 999){
440441
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
441442
} else {
442443
ggml_cpy_flt_cuda<half, half, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);

ggml/src/ggml.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3301,9 +3301,6 @@ static struct ggml_tensor * ggml_cont_impl(
33013301

33023302
result->op = GGML_OP_CONT;
33033303
result->src[0] = a;
3304-
if (a->op == GGML_OP_TRANSPOSE) {
3305-
result->op_params[10] = a->op_params[10]; // preserve the original order
3306-
}
33073304

33083305
return result;
33093306
}

0 commit comments

Comments
 (0)