232 changes: 225 additions & 7 deletions ggml/src/ggml-cuda/cpy.cu
@@ -1,3 +1,4 @@
#include <algorithm>
Collaborator:

Suggested change: remove `#include <algorithm>`

I think this header is no longer needed.

Contributor Author:

Yes, removed.

Contributor Author:

I also made a few adjustments to reduce shared-memory bank conflicts for the fp16 and bf16 types; this seems to give a small bandwidth bump for those types. (A minimal sketch of the idea follows the numbers below.)

  CPY(type_src=f32,type_dst=f32,ne=[786432,256,1,1],permute_src=[1,0,2,3],permute_dst=[0,0,0,0],_src_transpose=0):               462 runs -  2247.70 us/run -  1572864 kB/run -  682.52 GB/s
  CPY(type_src=f16,type_dst=f16,ne=[786432,256,1,1],permute_src=[1,0,2,3],permute_dst=[0,0,0,0],_src_transpose=0):               731 runs -  1392.03 us/run -   786432 kB/run -  545.05 GB/s
  CPY(type_src=f16,type_dst=f16,ne=[768,1024,256,1],permute_src=[1,0,2,3],permute_dst=[0,0,0,0],_src_transpose=0):              1118 runs -   922.83 us/run -   786432 kB/run -  822.17 GB/s
  CPY(type_src=bf16,type_dst=bf16,ne=[768,1024,256,1],permute_src=[1,0,2,3],permute_dst=[0,0,0,0],_src_transpose=0):                    1118 runs -   923.18 us/run -   786432 kB/run -  821.86 GB/s
  CPY(type_src=f32,type_dst=f32,ne=[786432,256,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1):               484 runs -  2143.72 us/run -  1572864 kB/run -  715.62 GB/s
  CPY(type_src=f32,type_dst=f32,ne=[768,1024,256,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1):               528 runs -  1952.04 us/run -  1572864 kB/run -  785.89 GB/s
  CPY(type_src=f16,type_dst=f16,ne=[786432,256,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1):               731 runs -  1381.74 us/run -   786432 kB/run -  549.11 GB/s
  CPY(type_src=f16,type_dst=f16,ne=[768,1024,256,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1):              1118 runs -   926.10 us/run -   786432 kB/run -  819.27 GB/s
  CPY(type_src=bf16,type_dst=bf16,ne=[768,1024,256,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1):                    1075 runs -   934.95 us/run -   786432 kB/run -  811.51 GB/s
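
For reference, here is a minimal standalone sketch of the swizzle idea (my simplified illustration; the kernel name, 32x32 tile, and one-thread-per-element launch are assumptions, not the exact kernel in this diff). XORing the column index with the row index permutes the banks per row, so the column-order reads of the transposed store no longer serialize:

```cuda
// Illustrative XOR-swizzled tile transpose (not the exact kernel in this PR).
// Logical tile element (row, col) is stored at tile[row][col ^ row]; since
// XOR with a fixed row is a bijection on 0..31, each row uses the 32 smem
// banks in a different order and the column-wise read back avoids conflicts.
template <typename T>
static __global__ void transpose_swizzle_sketch(const T * src, T * dst,
                                                const int rows, const int cols) {
    __shared__ T tile[32][32];

    const int x = blockIdx.x * 32 + threadIdx.x; // src column
    const int y = blockIdx.y * 32 + threadIdx.y; // src row

    if (x < cols && y < rows) {
        tile[threadIdx.y][threadIdx.x ^ threadIdx.y] = src[y*cols + x];
    }
    __syncthreads();

    const int tx = blockIdx.y * 32 + threadIdx.x; // dst column (src row)
    const int ty = blockIdx.x * 32 + threadIdx.y; // dst row (src column)

    if (tx < rows && ty < cols) {
        // apply the same XOR to undo the swizzle
        dst[ty*rows + tx] = tile[threadIdx.x][threadIdx.y ^ threadIdx.x];
    }
}
```

The kernel in the diff does the same thing with a 32x8 block and a row loop so each thread moves four elements; the swizzle arithmetic is identical.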

Contributor Author:

I ended up with padding instead of swizzling. Padding is the better fit here since the shared-memory requirement is loose.
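
For comparison, a minimal sketch of the padded-tile variant (again my simplified illustration with assumed names and sizes, not the final kernel). The extra column shifts each row's starting bank by one, so column-wise accesses also spread across the banks, with no index arithmetic but one extra column of shared memory per tile:

```cuda
// Illustrative padded-tile transpose (not the exact kernel in this PR).
// With a row stride of 33 elements, bank(row, col) = (33*row + col) % 32
// = (row + col) % 32 for 4-byte types, so a column touches all 32 banks.
template <typename T>
static __global__ void transpose_pad_sketch(const T * src, T * dst,
                                            const int rows, const int cols) {
    __shared__ T tile[32][32 + 1]; // +1 column breaks the bank alignment

    const int x = blockIdx.x * 32 + threadIdx.x;
    const int y = blockIdx.y * 32 + threadIdx.y;

    if (x < cols && y < rows) {
        tile[threadIdx.y][threadIdx.x] = src[y*cols + x];
    }
    __syncthreads();

    const int tx = blockIdx.y * 32 + threadIdx.x;
    const int ty = blockIdx.x * 32 + threadIdx.y;

    if (tx < rows && ty < cols) {
        dst[ty*rows + tx] = tile[threadIdx.x][threadIdx.y];
    }
}
```

Both approaches remove the conflicts; padding just spends one extra column of smem per tile, which is not a concern here since these kernels are nowhere near the shared-memory limit.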

#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
@@ -7,6 +8,11 @@

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

const int CUDA_CPY_TILE_DIM = 16;
const int CUDA_CPY_TILE_DIM_2D = 32;
const int CUDA_CPY_BLOCK_NM = 8;
const int CUDA_CPY_BLOCK_ROWS = 8;

template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,6 +41,143 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}

template <typename T>
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {

const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst);

const int64_t nmat = ne / (ne00 * ne01);
const int64_t n = ne00 * ne01;

int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;

__shared__ T tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D];
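// the tile is stored XOR-swizzled: logical (row, col) lives at tile[row][col ^ row],
// which spreads both the row-wise stores and the column-wise loads across smem banks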

for(int i = 0; i < CUDA_CPY_BLOCK_NM; ++i){

const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
if(imat >= nmat)
break;
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
if(x < ne01 && y + j < ne00){
const int row = threadIdx.y+j;
const int col = threadIdx.x ^ row;
tile[row][col] = src[imat*n + (y+j)*ne01 + x];
}
}
__syncthreads();

for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS){
if(ty + j < ne01 && tx < ne00){
const int col = (threadIdx.y+j) ^ threadIdx.x;
dst[imat*n + (ty+j)*ne00 + tx] = tile[threadIdx.x][col];
}
}
}
}


template <typename T, const int zero_at, const int one_at>
static __global__ void cpy_flt_coalesced(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {

const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst);

const int64_t n0 = ne00 * ne01;
const int64_t n1 = ne10 * ne11;

int x = blockIdx.x * CUDA_CPY_TILE_DIM + threadIdx.x;
int y = blockIdx.y * CUDA_CPY_TILE_DIM + threadIdx.y;
int z = blockIdx.z * CUDA_CPY_TILE_DIM;

__shared__ T tile[CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM][CUDA_CPY_TILE_DIM];

for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
if(x < ne00 && y < ne01 && z + k < ne02){
const int row = threadIdx.y;
const int col = threadIdx.x;
tile[k][row][col] = src[(z+k)*n0 + y*ne00 + x];
}
}
__syncthreads();

if(zero_at == 2){
int tx = blockIdx.z * CUDA_CPY_TILE_DIM;
if(one_at == 0){
int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][k][threadIdx.y];
}
}
} else{ // one at 1
int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
int ty = blockIdx.y * CUDA_CPY_TILE_DIM;
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.x][threadIdx.y][k];
}
}
}
} else if(zero_at == 1){
int tx = blockIdx.y * CUDA_CPY_TILE_DIM;
if(one_at == 0){
int ty = blockIdx.x * CUDA_CPY_TILE_DIM;
int tz = blockIdx.z * CUDA_CPY_TILE_DIM;
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[k][threadIdx.x][threadIdx.y];
}
}
} else { // one at 2
int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
int tz = blockIdx.x * CUDA_CPY_TILE_DIM;
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][threadIdx.x][k];
}
}
}
} else{ // zero_at == 0: the only permuted case left is one_at == 2 (and axis 2 at position 1); anything else would be fully contiguous
int tx = blockIdx.x * CUDA_CPY_TILE_DIM;
int ty = blockIdx.z * CUDA_CPY_TILE_DIM;
int tz = blockIdx.y * CUDA_CPY_TILE_DIM;
for(int k = 0; k < CUDA_CPY_TILE_DIM; ++k){
if(tz + k < ne12 && ty + threadIdx.y < ne11 && tx + threadIdx.x < ne10){
dst[(tz + k)*n1 + (ty+threadIdx.y)*ne10 + tx + threadIdx.x] = tile[threadIdx.y][k][threadIdx.x];
}
}
}
}

static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);

@@ -136,15 +279,76 @@ cudaStream_t stream) {
(cx, cdst, ne);
}

template<typename src_t, typename dst_t>
template<typename src_t, typename dst_t, bool coalesced = false>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
if (coalesced){ // transposed or otherwise permuted src
// GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
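// only dims 0 and 1 are swapped and the batch dims stay outermost, so each
// ne00 x ne01 matrix can be transposed independently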
if( nb00 < nb02 && nb02 <= nb03 ) {
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else{
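// general permutation: sort the three src axes by stride so the kernel
// indexes them fastest-first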
std::vector<std::tuple<int, int, int>> v;
v.emplace_back(nb00, ne00, 0);
v.emplace_back(nb01, ne01, 1);
v.emplace_back(nb02, ne02, 2);
std::sort(v.begin(), v.end(),
[](auto &a, auto &b) {
return std::get<0>(a) < std::get<0>(b);
});
const int ne0_new = std::get<1>(v[0]);
const int ne1_new = std::get<1>(v[1]);
const int ne2_new = std::get<1>(v[2]);
int nidx[3];
nidx[0] = std::get<2>(v[0]);
nidx[1] = std::get<2>(v[1]);
nidx[2] = std::get<2>(v[2]);
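// zero_at / one_at give the positions of original axes 0 and 1 in the
// stride-sorted order (2 = slowest); they select the kernel specialization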
const int zero_at = nidx[2] == 0 ? 2 : (nidx[1] == 0 ? 1 : 0);
const int one_at = nidx[2] == 1 ? 2 : (nidx[1] == 1 ? 1 : 0);

dim3 dimGrid((ne0_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
(ne1_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
(ne2_new + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM);
dim3 dimBlock(CUDA_CPY_TILE_DIM, CUDA_CPY_TILE_DIM, 1);

if(zero_at == 2){
if(one_at == 1){
cpy_flt_coalesced<dst_t, 2, 1><<<dimGrid, dimBlock, 0, stream>>>(
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
nb10, nb11, nb12, nb13);
}else{
cpy_flt_coalesced<dst_t, 2, 0><<<dimGrid, dimBlock, 0, stream>>>(
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
nb10, nb11, nb12, nb13);
}
} else if(zero_at == 1){
if(one_at == 2){
cpy_flt_coalesced<dst_t, 1, 2><<<dimGrid, dimBlock, 0, stream>>>(
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
nb10, nb11, nb12, nb13);
}else{
cpy_flt_coalesced<dst_t, 1, 0><<<dimGrid, dimBlock, 0, stream>>>(
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
nb10, nb11, nb12, nb13);
}
} else{
cpy_flt_coalesced<dst_t, 0, 2><<<dimGrid, dimBlock, 0, stream>>>(
cx, cdst, ne, ne0_new, ne1_new, ne2_new, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
nb10, nb11, nb12, nb13);
}
}
} else{ // other
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}

static void ggml_cpy_f32_q8_0_cuda(
@@ -310,6 +514,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src1_ddc = (char *) src1->data;

const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
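// the fast transpose path requires src0 to be a transpose view that is either
// 3D (ne[3] == 1) or keeps the outer dims in order (nb[0] < nb[2] <= nb[3])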
const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) &&
(src0->ne[3] == 1 || (src0->nb[2] <= src0->nb[3] && src0->nb[0] < src0->nb[2]));

if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -322,7 +528,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if(can_be_transposed){
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +571,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if(can_be_transposed){
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +589,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if(can_be_transposed){
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);