186 changes: 111 additions & 75 deletions ggml/src/ggml-cuda/binbcast.cu
@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}



template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
const int ne0, const int ne1, const int ne2, const int ne3,
const int ne10, const int ne11, const int ne12, const int ne13,
/*int s0, */ const int s1, const int s2, const int s3,
/*int s00,*/ const int s01, const int s02, const int s03,
/*int s10,*/ const int s11, const int s12, const int s13,
src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;

if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
template <float (*bin_op)(const float, const float),
typename src0_t,
typename src1_t,
typename dst_t,
typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
const src1_t * src1,
dst_t * dst,
const int ne0,
const int ne1,
const int ne2,
const uint3 ne3_fastdiv,
const uint3 ne10_fastdiv,
const uint3 ne11_fastdiv,
const uint3 ne12_fastdiv,
const uint3 ne13_fastdiv,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3_fastdiv);
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3_fastdiv.z);

if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3_fastdiv.z) {
return;
}

const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const uint32_t i11 = fastmodulo(i1, ne11_fastdiv);
const uint32_t i12 = fastmodulo(i2, ne12_fastdiv);
const uint32_t i13 = fastmodulo(i3, ne13_fastdiv);

const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;

for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10_fastdiv);

float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
}
}

template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
const int ne0, const int ne1, const int ne2,const int ne3,
const int ne10, const int ne11, const int ne12, const int ne13,
/*int s0, */ const int s1, const int s2, const int s3,
/*int s00,*/ const int s01, const int s02, const int s03,
/*int s10,*/ const int s11, const int s12, const int s13,
src1_ptrs ... src1s) {
template <float (*bin_op)(const float, const float),
typename src0_t,
typename src1_t,
typename dst_t,
typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const src1_t * src1,
dst_t * dst,
const uint3 ne0_fastdiv,
const uint3 ne1_fastdiv,
const uint3 ne2_fastdiv,
const uint32_t ne3,
const uint3 ne012_fastdiv,
const uint3 ne01_fastdiv,
const uint3 ne10_fastdiv,
const uint3 ne11_fastdiv,
const uint3 ne12_fastdiv,
const uint3 ne13_fastdiv,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

const int i3 = i/(ne2*ne1*ne0);
const int i2 = (i/(ne1*ne0)) % ne2;
const int i1 = (i/ne0) % ne1;
const int i0 = i % ne0;
const uint32_t i3 = fastdiv(i, ne012_fastdiv);
const uint32_t i2 = fastmodulo(fastdiv(i, ne01_fastdiv), ne2_fastdiv);
const uint32_t i1 = fastmodulo(fastdiv(i, ne0_fastdiv), ne1_fastdiv);
const uint32_t i0 = fastmodulo(i, ne0_fastdiv);

if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
if (i0 >= ne0_fastdiv.z || i1 >= ne1_fastdiv.z || i2 >= ne2_fastdiv.z || i3 >= ne3) {
return;
}

const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const int i11 = fastmodulo(i1, ne11_fastdiv);
const int i12 = fastmodulo(i2, ne12_fastdiv);
const int i13 = fastmodulo(i3, ne13_fastdiv);

const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;

const int i10 = i0 % ne10;
const int i10 = fastmodulo(i0, ne10_fastdiv);

float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];

size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
@@ -233,48 +264,53 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);


const uint3 ne10_fastdiv = init_fastdiv_values((uint32_t) cne1[0]);
const uint3 ne11_fastdiv = init_fastdiv_values((uint32_t) cne1[1]);
const uint3 ne12_fastdiv = init_fastdiv_values((uint32_t) cne1[2]);
const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) cne1[3]);

if (block_nums.z > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
const uint3 ne012_fastdiv = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 ne01_fastdiv = init_fastdiv_values((uint32_t) (ne0 * ne1));
if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv,
ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv,
ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10_fastdiv, ne11_fastdiv, ne12_fastdiv,
ne13_fastdiv,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv,
ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
}
}
}
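
Note on the technique: the new uint3 ne*_fastdiv kernel parameters replace per-element integer division and modulo (i / ne and i % ne) with fastdiv()/fastmodulo() on values packed once on the host by init_fastdiv_values(); the kernels read the original extent back from the .z component (e.g. ne3_fastdiv.z in the bounds check and remainder). The sketch below shows the standard magic-multiply division trick (Granlund–Montgomery style) that such helpers are typically built on. It is a minimal, host-runnable illustration under that assumption: the names pack_fastdiv, fast_div and fast_mod and the exact packing are hypothetical, not the actual ggml-cuda helpers.

// Illustrative sketch only: pack_fastdiv/fast_div/fast_mod are hypothetical names,
// not the real init_fastdiv_values()/fastdiv()/fastmodulo() helpers from ggml-cuda.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct fastdiv_vals {
    uint32_t magic;   // precomputed multiplier
    uint32_t shift;   // ceil(log2(d))
    uint32_t d;       // original divisor (the diff reads this back as the .z component)
};

// Host-side precomputation, done once per tensor dimension before the kernel launch.
static fastdiv_vals pack_fastdiv(uint32_t d) {
    uint32_t shift = 0;
    while ((uint64_t{1} << shift) < d) {
        ++shift;
    }
    const uint64_t magic = ((uint64_t{1} << 32) * ((uint64_t{1} << shift) - d)) / d + 1;
    return { (uint32_t) magic, shift, d };
}

// n / d without a hardware divide: multiply-high, add, shift.
// A device version would use __umulhi(); the 64-bit arithmetic keeps this sketch host-runnable.
static uint32_t fast_div(uint32_t n, const fastdiv_vals & f) {
    const uint32_t hi = (uint32_t) (((uint64_t) f.magic * n) >> 32);
    return (uint32_t) (((uint64_t) hi + n) >> f.shift);
}

// n % d recovered from the fast quotient and the stored divisor.
static uint32_t fast_mod(uint32_t n, const fastdiv_vals & f) {
    return n - fast_div(n, f) * f.d;
}

int main() {
    const uint32_t divisors[] = { 1, 3, 7, 320, 4096 };
    for (uint32_t d : divisors) {
        const fastdiv_vals f = pack_fastdiv(d);
        for (uint32_t n = 0; n < 100000; ++n) {
            assert(fast_div(n, f) == n / d);
            assert(fast_mod(n, f) == n % d);
        }
    }
    printf("fastdiv/fastmodulo sketch matches plain / and %% on the tested range\n");
    return 0;
}

The practical effect inside the kernels is that index math such as i2 = idx / ne3 and i13 = i3 % ne13, which runs for every broadcast element, becomes a few multiply/add/shift instructions instead of a hardware divide.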