76 changes: 69 additions & 7 deletions ggml/src/ggml-cuda/cpy.cu
@@ -35,6 +35,48 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}

template <typename T>
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
        const int nb12, const int nb13) {

    const T * src = reinterpret_cast<const T *>(cx);
    T * dst = reinterpret_cast<T *>(cdst);

    const int64_t nmat = ne / (ne00*ne01); // number of ne00 x ne01 matrices to transpose
    const int64_t n    = ne00*ne01;        // elements per matrix

    const int x  = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
    const int y  = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.x; // transposed block offset
    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.y;

    __shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];

    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
        if (imat >= nmat) {
            break;
        }

        // load one CUDA_CPY_TILE_DIM x CUDA_CPY_TILE_DIM tile into shared memory,
        // XOR-swizzling the column index to avoid shared-memory bank conflicts
        for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS) {
            if (x < ne01 && y + j < ne00) {
                const int row = threadIdx.y + j;
                const int col = threadIdx.x ^ row;
                tile[row][col] = src[imat*n + (y+j)*ne01 + x];
            }
        }
        __syncthreads();

        // store the tile transposed, undoing the swizzle (XOR is self-inverse)
        for (int j = 0; j < CUDA_CPY_TILE_DIM; j += CUDA_CPY_BLOCK_ROWS) {
            if (ty + j < ne01 && tx < ne00) {
                const int col = (threadIdx.y + j) ^ threadIdx.x;
                dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
            }
        }
        // wait until all threads have finished reading the tile before the
        // next iteration overwrites it
        __syncthreads();
    }
}
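Note: the swizzle stores logical tile element [r][c] at physical column c ^ r, which spreads each logical column across the shared-memory banks on both the row-major load and the column-major store, without the usual +1 padding column. A standalone host-side sketch (not part of this diff; TILE stands in for CUDA_CPY_TILE_DIM) that checks the mapping round-trips:

#include <cassert>
#include <cstdio>

int main() {
    const int TILE = 32;
    for (int r = 0; r < TILE; ++r) {
        bool seen[TILE] = {false};
        for (int c = 0; c < TILE; ++c) {
            const int phys = c ^ r;
            assert(!seen[phys]);     // c -> c ^ r permutes the columns of row r
            seen[phys] = true;
            assert((phys ^ r) == c); // XOR-ing again recovers the logical column
        }
    }
    printf("swizzle is a per-row permutation and self-inverse\n");
    return 0;
}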

static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);

@@ -136,15 +178,23 @@ cudaStream_t stream) {
(cx, cdst, ne);
}

template<typename src_t, typename dst_t>
template<typename src_t, typename dst_t, bool transpose = false>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
if (transpose) {
    // one tile grid over the ne01 x ne00 output, CUDA_CPY_BLOCK_NM matrices per z-block
    dim3 dimGrid((ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
                 (ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
                 (ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
    dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_BLOCK_ROWS, 1);
    cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}
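Note: as a concrete check of the launch arithmetic (shape taken from the {768, 1024, 256, 1} perf case added below), transposing 256 matrices with ne00 = 768 and ne01 = 1024 gives a 32 x 24 x 32 grid of 32 x 8 thread blocks, each block handling 8 tiles and each thread moving 4 elements per tile. A standalone sketch of the same computation:

#include <cstdio>

int main() {
    const int  TILE = 32, ROWS = 8, NM = 8;          // CUDA_CPY_TILE_DIM / _BLOCK_ROWS / _BLOCK_NM
    const long ne00 = 768, ne01 = 1024, nmat = 256;  // assumed tensor shape
    const long ne   = ne00*ne01*nmat;

    const long gx = (ne01 + TILE - 1) / TILE;        // 32
    const long gy = (ne00 + TILE - 1) / TILE;        // 24
    const long gz = (ne/(ne00*ne01) + NM - 1) / NM;  // 32
    printf("grid = (%ld, %ld, %ld), block = (%d, %d, 1), elems/thread/tile = %d\n",
           gx, gy, gz, TILE, ROWS, TILE/ROWS);
    return 0;
}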

static void ggml_cpy_f32_q8_0_cuda(
@@ -322,7 +372,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0)) {
    ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
    ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +415,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0)) {
    ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
    ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +433,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0)) {
    ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
    ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/cpy.cuh
@@ -1,6 +1,9 @@
#include "common.cuh"

#define CUDA_CPY_BLOCK_SIZE 64
#define CUDA_CPY_TILE_DIM 32
#define CUDA_CPY_BLOCK_ROWS 8
#define CUDA_CPY_BLOCK_NM 8

void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
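Note on the new constants (my reading, not stated in the PR): CUDA_CPY_TILE_DIM matches the warp width, so a block of 32 x CUDA_CPY_BLOCK_ROWS = 256 threads covers one tile with each thread handling 32 / 8 = 4 tile rows; one f32 tile costs 32 * 32 * 4 = 4 KiB of shared memory; and CUDA_CPY_BLOCK_NM amortizes each block launch over 8 matrices along blockIdx.z.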

27 changes: 25 additions & 2 deletions tests/test-backend-ops.cpp
@@ -2509,6 +2509,7 @@ struct test_cpy : public test_case {
const std::array<int64_t, 4> permute_dst;
bool _src_use_permute;
bool _dst_use_permute;
bool _src_transpose;

std::string vars() override {
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
@@ -2549,10 +2550,12 @@
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1},
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
bool transpose_src = false)
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
_src_transpose(transpose_src) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2564,6 +2567,11 @@
ggml_set_name(src, "src_permuted");
}

if (_src_transpose) {
src = ggml_transpose(ctx, src);
ggml_set_name(src, "src_transposed");
}

ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
ggml_set_name(dst, "dst");

@@ -6513,6 +6521,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));

test_cases.emplace_back(new test_cont());
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -7244,6 +7255,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));

test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));

test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));

test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));