Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions ggml/src/ggml-cann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,11 @@ if (CANN_INSTALL_DIR)
${CANN_INSTALL_DIR}/acllib/include
)

add_subdirectory(kernels)
list(APPEND CANN_LIBRARIES
ascendcl
nnopbase
opapi
acl_op_compiler
ascendc_kernels
)

file(GLOB GGML_SOURCES_CANN "*.cpp")
Expand Down
467 changes: 241 additions & 226 deletions ggml/src/ggml-cann/aclnn_ops.cpp

Large diffs are not rendered by default.

26 changes: 15 additions & 11 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,24 +1704,28 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
default:
return false;
}
} break;
case GGML_OP_CPY: {
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
return true;
default:
return false;
ggml_tensor *src = op->src[0];
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
(src->type != GGML_TYPE_F32 &&
src->type != GGML_TYPE_F16)) {
// only support F32 and F16.
return false;
}
}

if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
// unsupport dst is not contiguous.
return false;
}

return true;
} break;
case GGML_OP_CONT: {
// TODO: support GGML_TYPE_BF16
switch (op->src[0]->type) {
Expand Down Expand Up @@ -1762,9 +1766,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
return true;
}
case GGML_OP_DUP:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_DUP:
case GGML_OP_REPEAT:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/tsembd.cuh"
Expand Down Expand Up @@ -2296,6 +2298,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
Expand Down Expand Up @@ -3193,6 +3201,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
Expand Down
151 changes: 151 additions & 0 deletions ggml/src/ggml-cuda/ssm-conv.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#include "ssm-conv.cuh"

template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;

const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);

const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);

float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };

#pragma unroll
for (int j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}

for (int i = 0; i < n_t; i++) {
float sumf = 0.0f;

if (i == 0) {
for (int j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}

#pragma unroll
for (int j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}

template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const int bidz = blockIdx.z;

const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
bidz * split_n_t * src0_nb0);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block =
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);

const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);

float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };

#pragma unroll
for (int j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}

#pragma unroll
for (int i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f;

if (i == 0) {
for (int j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}

#pragma unroll
for (int j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}
}

static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
const int n_s, cudaStream_t stream) {
const int threads = 128;
GGML_ASSERT(nr % threads == 0);

if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
n_s);
} else {
GGML_ABORT("Only support kernel size = 4 now.");
}
} else {
if (nc == 4) {
const int split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t>
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
} else {
GGML_ABORT("Only support kernel size = 4 right now.");
}
}
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch

GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
}
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/ssm-conv.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Loading
Loading