
Commit a3784e1

WIP: debugging cpy transpose
1 parent cc327f5 commit a3784e1

File tree

4 files changed: +81 -40 lines changed

    ggml/src/ggml-cuda/cpy.cu
    ggml/src/ggml.c
    tests/test-backend-ops.cpp
    tests/test-conv2d-implicit.cpp

ggml/src/ggml-cuda/cpy.cu

Lines changed: 40 additions & 20 deletions
@@ -39,7 +39,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
 
 
 template <typename T>
-static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct,, const int ne,
+static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
     const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
@@ -58,22 +58,31 @@ static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct,, const i
     int tx = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
     int ty = blockIdx.x * TILE_DIM + threadIdx.y;
 
-    __shared__ T tile[TILE_DIM * TILE_DIM];
+    // __shared__ T tile[TILE_DIM * TILE_DIM];
+    __shared__ T tile[TILE_DIM][TILE_DIM];
 
     for(int i = 0; i < BLOCK_NM; ++i){
         const unsigned int imat = blockIdx.z * BLOCK_NM + i;
         if(imat < nmat){
             for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){
                 const unsigned int idx = (y+j)*width + x;
-                if(idx < n)
-                    tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
+                if(idx < n){
+                    const int row = threadIdx.y+j;
+                    const int col = threadIdx.x ^ row;
+                    // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
+                    tile[row][col] = src[imat*n + idx];
+                }
             }
             __syncthreads();
 
             for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){
                 const unsigned int idx = (ty+j)*width + tx;
-                if(idx < n)
-                    dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
+                if(idx < n){
+                    // const int row = threadIdx.x;
+                    const int col = (threadIdx.y+j) ^ threadIdx.x;
+                    // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
+                    dst[imat*n + idx] = tile[threadIdx.x][col];
+                }
             }
         }
     }
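
Note on this hunk: the substantive change is the shared-memory tile, which goes from a flat tile[TILE_DIM * TILE_DIM] to a 2D tile[TILE_DIM][TILE_DIM] whose column index is XOR-swizzled with the row index, so the column-wise reads in the write-out loop land in distinct shared-memory banks (the same effect as the more common [TILE_DIM][TILE_DIM + 1] padding). A minimal standalone sketch of the pattern for a single matrix, assuming TILE_DIM = 32, BLOCK_ROWS = 8, and dimensions divisible by the tile size (the commit's actual constants and bounds checks live elsewhere in cpy.cu):

    constexpr int TILE_DIM   = 32;
    constexpr int BLOCK_ROWS = 8;

    // launch with grid(width / TILE_DIM, height / TILE_DIM) and block(TILE_DIM, BLOCK_ROWS);
    // src is height x width (row-major), dst is width x height
    template <typename T>
    __global__ void transpose_tile_swizzled(const T * __restrict__ src, T * __restrict__ dst,
                                            int width, int height) {
        __shared__ T tile[TILE_DIM][TILE_DIM];

        const int x = blockIdx.x * TILE_DIM + threadIdx.x;   // source column
        const int y = blockIdx.y * TILE_DIM + threadIdx.y;   // source row

        for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
            const int row = threadIdx.y + j;
            // element (row, threadIdx.x) of the tile is stored at column threadIdx.x ^ row
            tile[row][threadIdx.x ^ row] = src[(y + j) * width + x];
        }
        __syncthreads();

        const int tx = blockIdx.y * TILE_DIM + threadIdx.x;  // destination column
        const int ty = blockIdx.x * TILE_DIM + threadIdx.y;  // destination row

        for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
            // logical read of tile[threadIdx.x][threadIdx.y + j], with the same swizzle applied
            dst[(ty + j) * height + tx] = tile[threadIdx.x][(threadIdx.y + j) ^ threadIdx.x];
        }
    }

The swizzle only works if the store (threadIdx.x ^ row) and the load ((threadIdx.y + j) ^ threadIdx.x) apply the same mapping, which is exactly what the two ^ lines added in this hunk do.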
@@ -180,30 +189,33 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 #endif
 }
 
-template<typename src_t, typename dst_t>
+template<typename src_t, typename dst_t, bool transpose = false>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    if constexpr (std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
-                  std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>
-                  ){
-        if (ne00 == ne11 && ne01 = ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
+    if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
+                   std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
+                  && transpose){
+        // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
+        // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11);
+        // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
+        // if (transpose) { //transpose
+        // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
         dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM,
                       (ne01 + TILE_DIM - 1) / TILE_DIM,
                       (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM );
         dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
-        cpy_flt_transpose<cpy_1_flt<dst_t><<<dimGrid, dimBlock, 0, stream>>>
-            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-    } else{ // other
-        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-    }
-    } else{
+        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    } else{ // other
         cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
     }
+    // } else{
+    //     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    //         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    // }
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
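
Note on this hunk: the removed lines contained two bugs (an assignment in the shape test, ne01 = ne10, and a malformed instantiation cpy_flt_transpose<cpy_1_flt<dst_t><<<...). The WIP fix sidesteps the runtime shape check entirely by threading a bool transpose template parameter through ggml_cpy_flt_cuda, so the branch is resolved by if constexpr at compile time and only same-type copies (float to float, half to half) can take it. The dispatch pattern in isolation, with toy stand-in kernels rather than the real cpy_flt/cpy_flt_transpose (a sketch, not the committed code):

    #include <cuda_runtime.h>

    // stand-in for cpy_flt: plain contiguous copy
    template <typename T>
    __global__ void copy_contiguous(const T * src, T * dst, int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) dst[i] = src[i];
    }

    // stand-in for cpy_flt_transpose: naive per-element transpose
    template <typename T>
    __global__ void copy_transposed(const T * src, T * dst, int rows, int cols) {
        const int c = blockIdx.x * blockDim.x + threadIdx.x;
        const int r = blockIdx.y * blockDim.y + threadIdx.y;
        if (r < rows && c < cols) dst[c * rows + r] = src[r * cols + c];
    }

    // the compile-time dispatch: each instantiation keeps only one branch
    template <typename T, bool transpose = false>
    static void launch_cpy(const T * src, T * dst, int rows, int cols, cudaStream_t stream) {
        if constexpr (transpose) {
            const dim3 block(16, 16);
            const dim3 grid((cols + block.x - 1) / block.x, (rows + block.y - 1) / block.y);
            copy_transposed<T><<<grid, block, 0, stream>>>(src, dst, rows, cols);
        } else {
            const int n = rows * cols;
            copy_contiguous<T><<<(n + 255) / 256, 256, 0, stream>>>(src, dst, n);
        }
    }

In the committed code the transpose branch additionally batches matrices: the grid's z dimension covers ne / (ne00*ne01) matrices, BLOCK_NM of them per thread block, which is what the dimGrid computation above expresses.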
@@ -389,7 +401,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if(src1->op_params[10] == 999){
+            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, float, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
@@ -420,7 +436,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if(src1->op_params[10] == 999){
+            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<half, half, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
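
Note on these two hunks: the src1->op_params[10] == 999 test is how the transpose request reaches the backend in this WIP state. ggml_transpose() stamps the flag on the view it returns and ggml_cont_impl() copies it onto the CONT node (both in the ggml.c hunks below); since the CUDA backend services both GGML_OP_CPY and GGML_OP_CONT through ggml_cuda_cpy() with the copy's destination tensor passed as src1, the flag is visible at exactly this point. A hedged sketch of the intended flow on the graph-construction side (ctx and a contiguous float tensor x assumed):

    struct ggml_tensor * t = ggml_transpose(ctx, x);   // view; this commit sets t->op_params[10] = 999
    struct ggml_tensor * c = ggml_cont(ctx, t);        // ggml_cont_impl copies the flag onto c

    // at execution time the CUDA backend runs the CONT node as a copy, effectively
    // ggml_cuda_cpy(ctx, /*src0=*/t, /*src1=*/c); it sees c->op_params[10] == 999 and
    // instantiates ggml_cpy_flt_cuda<float, float, true>, i.e. the tiled cpy_flt_transpose
    // kernel instead of the generic strided copy

The magic value in op_params[10] is debugging scaffolding (the commit message says WIP); any other node that happens to leave 999 in that slot of a copy destination would be routed to the transpose kernel as well.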

ggml/src/ggml.c

Lines changed: 16 additions & 2 deletions
@@ -3301,6 +3301,9 @@ static struct ggml_tensor * ggml_cont_impl(
 
     result->op = GGML_OP_CONT;
     result->src[0] = a;
+    if (a->op == GGML_OP_TRANSPOSE) {
+        result->op_params[10] = a->op_params[10]; // preserve the original order
+    }
 
     return result;
 }
@@ -3614,6 +3617,7 @@ struct ggml_tensor * ggml_transpose(
 
     result->op = GGML_OP_TRANSPOSE;
     result->src[0] = a;
+    result->op_params[10] = 999; // the transpose flag
 
     return result;
 }
@@ -4609,8 +4613,18 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm(
 
     struct ggml_tensor *ap, *bp;
     if(layout == 0){
-        ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3));
-        bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3));
+        // ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3));
+        // bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3));
+        ap = ggml_reshape_4d(ctx,
+                 ggml_cont(ctx,
+                     ggml_transpose(ctx,
+                         ggml_reshape_3d(ctx, a, a->ne[0]*a->ne[1], a->ne[2], a->ne[3]))),
+                 a->ne[2], a->ne[0], a->ne[1], a->ne[3]);
+        bp = ggml_reshape_4d(ctx,
+                 ggml_cont(ctx,
+                     ggml_transpose(ctx,
+                         ggml_reshape_3d(ctx, b, b->ne[0]*b->ne[1], b->ne[2], b->ne[3]))),
+                 b->ne[2], b->ne[0], b->ne[1], b->ne[3]);
     } else{
         ap = a;
         bp = b;
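
Note on the ggml_conv_2d_implicitgemm hunk: ggml_cont(ggml_permute(a, 1, 2, 0, 3)) and the new reshape, transpose, cont, reshape chain produce the same channel-innermost tensor; the rewrite only re-expresses the data movement as one plain 2D transpose per batch slice, which is the case the new CUDA kernel handles. The shape flow for a (same for b), written out as a sketch with ne = {ne0, ne1, ne2, ne3} labelled {W, H, C, N} for readability:

    // a                                                   ne = {W, H, C, N}
    struct ggml_tensor * a3 = ggml_reshape_3d(ctx, a, a->ne[0]*a->ne[1], a->ne[2], a->ne[3]);
    //                                                     ne = {W*H, C, N}
    struct ggml_tensor * at = ggml_transpose(ctx, a3);     // strided view, flagged with op_params[10] = 999
    //                                                     ne = {C, W*H, N}
    struct ggml_tensor * ac = ggml_cont(ctx, at);          // materialized by cpy_flt_transpose on CUDA
    struct ggml_tensor * ap = ggml_reshape_4d(ctx, ac, a->ne[2], a->ne[0], a->ne[1], a->ne[3]);
    //                                                     ne = {C, W, H, N}, channel innermost,
    //                                                     same elements as ggml_cont(ggml_permute(a, 1, 2, 0, 3))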

tests/test-backend-ops.cpp

Lines changed: 20 additions & 13 deletions
@@ -2414,6 +2414,7 @@ struct test_cpy : public test_case {
     const std::array<int64_t, 4> permute_dst;
     bool _src_use_permute;
     bool _dst_use_permute;
+    bool is_transpose;
 
     std::string vars() override {
         return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
@@ -2430,10 +2431,12 @@ struct test_cpy : public test_case {
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
             std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
-            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
+            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
+            bool transpose = false)
         : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
           _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
-          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
+          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
+          is_transpose(transpose) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2454,6 +2457,8 @@ struct test_cpy : public test_case {
         }
 
         ggml_tensor * out = ggml_cpy(ctx, src, dst);
+        if(is_transpose)
+            dst->op_params[10] = 999;
         ggml_set_name(out, "out");
 
         return out;
@@ -4258,14 +4263,14 @@ struct test_conv_2d_implicit : public test_case {
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");
 
-        if (cwhn) {
-            // change memory layout to channel-most-contiguous (CWHN),
-            // then permute it back so NE matches the original input
-            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
-            input = ggml_permute(ctx, input, 2, 0, 1, 3);
-            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
-            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
-        }
+        // if (cwhn) {
+        //     // change memory layout to channel-most-contiguous (CWHN),
+        //     // then permute it back so NE matches the original input
+        //     input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+        //     input = ggml_permute(ctx, input, 2, 0, 1, 3);
+        //     kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+        //     kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        // }
 
         ggml_tensor * out =
             ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn?0:1);
@@ -6831,9 +6836,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+    // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+    // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
+    // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, false));
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
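
Note on the perf hunk: the two new entries form an A/B pair over the same F32 tensor of 768*1024 x 256 elements (about 768 MiB at 4 bytes per element) with the same permuted destination, differing only in the final transpose argument; the constructor stores it in the is_transpose member added above, and build_graph() turns it into dst->op_params[10] = 999 on the copy destination. A perf run filtered to the CPY op (typically test-backend-ops perf -o CPY, exact flags depending on the harness version) therefore compares the tiled cpy_flt_transpose path directly against the generic strided copy on an identical workload.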

tests/test-conv2d-implicit.cpp

Lines changed: 5 additions & 5 deletions
@@ -353,10 +353,10 @@ int main(void)
         // std::make_tuple(640,640,52,76,3,3),
         // std::make_tuple(640,640,104,152,3,3),
         // std::make_tuple(960,320,104,152,3,3),
-        // std::make_tuple(1280,1280,26,38,3,3),
+        std::make_tuple(1280,1280,26,38,3,3),
         // std::make_tuple(1280,1280,26,38,1,1),
         // std::make_tuple(256,128,768,1024,3,3),
-        std::make_tuple(256,128,768,1024,1,1),
+        // std::make_tuple(256,128,768,1024,1,1),
         // std::make_tuple(1280,640,52,76,3,3),
         // std::make_tuple(1920,1280,26,38,3,3),
         // std::make_tuple(2560,1280,26,38,3,3),
@@ -451,16 +451,16 @@ int main(void)
 
     // for(int i = 0; i < ggml_nelements(wino_res); i++) {
     // for(int i = 0; i < 26*38; i++) {
-    // // for(int i = 0; i < conv2d_data.size(); i++) {
+    // for(int i = 0; i < conv2d_data.size(); i++) {
     // //     float diff = fabs(conv2d_data[i] - wino_data[i]);
     //     float diff = fabs(im2col_data[i] - wino_data[i]);
     //     float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
-    // //     if(diff > 1.e-4) {
+    //     if(diff > 0.5) {
    //         printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
     //                 im2col_data[i], conv2d_data[i],
     //                 wino_data[i], diff, diff1, i);
     //         // break;
-    // //     }
+    //     }
     //     }
 
     ggml_free(model.ctx);
