Skip to content

Commit 230d116

Browse files
bssrdf and authored
improve CUDA cpy memory bandwidth when copying transposed tensor (ggml-org#16841)
* WIP * added a cpy kernel specific to transposed tensor which uses smem to avoid uncoalesced access; test cases also added shwoing improved memory bandwidth * added BF16 support * more strict check to make sure src0 is a transpose * reformulated to handle more complicated transpose cases * bring back 2D transpose for higher performance * allow build on windows * tranpose copy more shapes * minor tweak * final clean up * restore some test cases * keep only the kernel for true tranposed case; updated with review suggestions * make CI happy * remove headers not needed * reduced bank conflicts for fp16 and bf16 * add missing const* * now bank conflicts free * use padding instead of swizzling --------- Co-authored-by: bssrdf <[email protected]>
1 parent a44d771 commit 230d116

File tree

2 files changed

+126
-10
lines changed

2 files changed

+126
-10
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
99

10+
const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
11+
const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
12+
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
13+
1014
template <cpy_kernel_t cpy_1>
1115
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
1216
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
3539
cpy_1(cx + x_offset, cdst + dst_offset);
3640
}
3741

// Copy a transposed src matrix (batch of ne/(ne00*ne01) matrices) into a
// contiguous dst using a tiled shared-memory transpose so both the global
// loads and the global stores are coalesced.
//
// Expected launch config (set up by ggml_cpy_flt_cuda):
//   block = (CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1)
//   grid.x tiles ne01, grid.y tiles ne00, and each grid.z block processes up
//   to CUDA_CPY_BLOCK_NM matrices of the batch.
//
// The tile is declared as float and re-indexed per element as T so the same
// float-strided layout serves fp32/fp16/bf16; the "+1" column padding keeps
// the column-wise reads free of shared-memory bank conflicts.
//
// NOTE(review): only ne, ne00 and ne01 are read; the remaining extent/stride
// parameters are accepted for signature parity with cpy_flt.
template <typename T>
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
    const int nb12, const int nb13) {

    const T* src = reinterpret_cast<const T*>(cx);
    T* dst = reinterpret_cast<T*>(cdst);

    const int64_t nmat = ne / (ne00 * ne01); // number of 2D matrices in the batch
    const int64_t n = ne00 * ne01;           // elements per matrix

    // (x, y): element this thread loads; (tx, ty): element it stores, i.e. the
    // same tile with the block coordinates swapped (transpose block offset).
    const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
    const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;

    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];

#pragma unroll
    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
        // uniform across the block (depends only on blockIdx.z and i), so the
        // barriers below are never divergent
        if (imat >= nmat)
            break;

        // load phase: coalesced reads from src into the shared tile
#pragma unroll
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (x < ne01 && y + j < ne00) {
                const int row = threadIdx.y + j;
                // place each T element at float stride so fp16/bf16 also land
                // in distinct banks
                const int col = threadIdx.x * sizeof(float)/sizeof(T);
                T * tile2 = reinterpret_cast<T*>(tile[row]);
                tile2[col] = src[imat*n + (y+j)*ne01 + x];
            }
        }

        __syncthreads(); // all writes to tile done before anyone reads it

        // store phase: read the tile column-wise, write dst coalesced
#pragma unroll
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (ty + j < ne01 && tx < ne00) {
                const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
                const T * tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
            }
        }

        // Fix: barrier after the read phase as well — without it, the next
        // iteration's load phase may overwrite `tile` while slower threads are
        // still reading the previous matrix (inter-iteration data race).
        __syncthreads();
    }
}
90+
3891
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
3992
float * cdstf = (float *)(cdsti);
4093

@@ -136,15 +189,38 @@ cudaStream_t stream) {
136189
(cx, cdst, ne);
137190
}
138191

// Launch a float-type copy (fp32/fp16/bf16) on `stream`.
//
// When `transposed` is set, dispatch to the tiled shared-memory transpose
// kernel instead of the generic element-wise copy. The transposed path does
// no type conversion (the kernel is instantiated on dst_t only), so callers
// only use it when src_t == dst_t, and it assumes ne[3] == 1.
template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    // `transposed` is a compile-time template parameter: `if constexpr` avoids
    // instantiating (and emitting code for) the unused branch of every
    // <src_t, dst_t> combination.
    if constexpr (transposed) {
        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
        // Normalize to a single 2D transpose problem:
        // - nb00 < nb02: a true per-matrix transpose; keep dims as-is and let
        //   grid.z batch over ne02.
        // - nb00 > nb02: dims 1 and 2 are contiguous in dst and can be fused
        //   into one longer transposed dimension.
        int ne00n, ne01n, ne02n;
        if (nb00 < nb02) {
            ne00n = ne00;
            ne01n = ne01;
            ne02n = ne02;
        } else if (nb00 > nb02) {
            ne00n = ne00;
            ne01n = ne01*ne02;
            ne02n = 1;
        } else {
            GGML_ASSERT(false); // equal strides: not a transpose this kernel handles
        }

        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
    } else {
        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
    }
}
149225

150226
static void ggml_cpy_f32_q8_0_cuda(
@@ -310,6 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
310386
char * src1_ddc = (char *) src1->data;
311387

312388
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
389+
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
313390

314391
if (src0->type == src1->type && contiguous_srcs) {
315392
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -322,7 +399,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
322399
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
323400
}
324401
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
325-
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
402+
if (can_be_transposed) {
403+
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
404+
} else {
405+
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
406+
}
326407
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
327408
if (contiguous_srcs) {
328409
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +442,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
361442
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
362443
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
363444
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
364-
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
445+
if (can_be_transposed) {
446+
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
447+
} else {
448+
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
449+
}
365450
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
366451
if (contiguous_srcs) {
367452
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +460,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
375460
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
376461
}
377462
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
378-
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
463+
if (can_be_transposed) {
464+
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
465+
} else {
466+
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
467+
}
379468
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
380469
if (contiguous_srcs) {
381470
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);

tests/test-backend-ops.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,9 +2576,10 @@ struct test_cpy : public test_case {
25762576
const std::array<int64_t, 4> permute_dst;
25772577
bool _src_use_permute;
25782578
bool _dst_use_permute;
2579+
bool _src_transpose;
25792580

25802581
std::string vars() override {
2581-
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
2582+
return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);
25822583
}
25832584

25842585
double max_nmse_err() override {
@@ -2616,10 +2617,12 @@ struct test_cpy : public test_case {
26162617
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
26172618
std::array<int64_t, 4> ne = {10, 10, 10, 1},
26182619
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
2619-
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
2620+
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
2621+
bool transpose_src = false)
26202622
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
26212623
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
2622-
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
2624+
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
2625+
_src_transpose(transpose_src){}
26232626

26242627
ggml_tensor * build_graph(ggml_context * ctx) override {
26252628
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2631,6 +2634,11 @@ struct test_cpy : public test_case {
26312634
ggml_set_name(src, "src_permuted");
26322635
}
26332636

2637+
if (_src_transpose) {
2638+
src = ggml_transpose(ctx, src);
2639+
ggml_set_name(src, "src_transposed");
2640+
}
2641+
26342642
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
26352643
ggml_set_name(dst, "dst");
26362644

@@ -6641,6 +6649,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
66416649
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
66426650
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
66436651
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
6652+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6653+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6654+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6655+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6656+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6657+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6658+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
66446659

66456660
test_cases.emplace_back(new test_cont());
66466661
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -7385,6 +7400,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
73857400
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
73867401
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
73877402

7403+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7404+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7405+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7406+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7407+
7408+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7409+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7410+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7411+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7412+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7413+
7414+
73887415
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
73897416
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
73907417
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

0 commit comments

Comments
 (0)