Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
f809568
Add initial/naive CUDA kernels for the GGML_OP_SSM_CONV and GGML_OP_S…
jploski Jun 1, 2024
cc365b0
Add GGML_OP_SSM_CONF, GGML_OP_SSM_SCAN to supported ops for CUDA back…
jploski Jun 1, 2024
25f9e65
Update CUDA ops ssm_conv and ssm_scan to match CPU implementation fro…
jploski Jun 2, 2024
64fbd32
Add patch to test cases provided by @compilade; test for ssm_conv fails
jploski Jun 2, 2024
12c913c
Fix backend test for ssm_conv CUDA op not working
jploski Jun 2, 2024
061e520
Update CUDA ops and tests to match implementation from commit 8fb57ac…
jploski Jun 3, 2024
fae826f
Fix failed assertions while running Falcon Mamba
jploski Aug 25, 2024
20d390b
10x performance improve 4 cuda ssm conv & scan
piDack Aug 26, 2024
8dd323b
Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_…
piDack Aug 27, 2024
b423a6d
fix ssm_scan numerical error & others update
piDack Aug 27, 2024
40f4787
Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_…
piDack Aug 27, 2024
1928967
resolve test-backend-ops conflicts
piDack Aug 27, 2024
21c16fa
fix trailing whitespace
piDack Aug 27, 2024
e53b14f
del debug ingo
piDack Aug 27, 2024
eec0e8c
memory access pattern
piDack Aug 27, 2024
0e682ce
add restrict
piDack Aug 27, 2024
5999d6d
fix conflicts
piDack Aug 28, 2024
316a049
add restrict for dst
piDack Aug 29, 2024
99f2ac1
Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_…
piDack Aug 29, 2024
63b6e73
recommit for ci pass
piDack Aug 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/ssm_conv.cuh"
#include "ggml-cuda/ssm_scan.cuh"

#include <algorithm>
#include <array>
Expand Down Expand Up @@ -2303,6 +2305,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
default:
return false;
}
Expand Down Expand Up @@ -2877,6 +2885,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
Expand Down
100 changes: 100 additions & 0 deletions ggml/src/ggml-cuda/ssm_conv.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#include "ssm_conv.cuh"

template <int block_size>
static __global__ void ssm_conv_f32(
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add __restrict__ to the pointers, see #2140 .


// const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;
Copy link
Collaborator

@JohannesGaessler JohannesGaessler Aug 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be aware that this is not a noop. There are special, shared registers for e.g. threadIdx.x and you are taking that data and moving it to regular registers. IIRC correctly the regular registers are slightly faster to access but it will also increase register pressure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.I believe that most GPU registers are more than sufficient to meet the requirements.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think I worded my previous post poorly. What I meant is that the regular registers are slightly faster (but there is only a limited amount of them).


const int ith = tid;
const int nth = WARP_SIZE;

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr * ith;
const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0;

// {d_conv - 1 + n_t, d_inner, n_seqs}
// sliding window
const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}
// TODO: transpose the output for smaller strides for big batches?
// d_inner
#pragma unroll
for (int i1 = 0; i1 < ir; ++i1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ir is not known at compile time so this loop cannot actually be unrolled, same for the other loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

// rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm missing something the memory access pattern here is bad with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.

Copy link
Collaborator

@compilade compilade Aug 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.

Copy link
Contributor Author

@piDack piDack Aug 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm missing something the memory access pattern here is bad with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.

Thx.Current memory access pattern is more suitable for CPUs. I'm thinking about ways to address this issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.

Good Idea,I am currently testing according to your method.

Copy link
Contributor Author

@piDack piDack Aug 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.

I’ve found a simple implementation for ssm_conv that can coalesce memory accesses,can optimize the 2x performance,and I’ve already submitted the PR!For the ssm_scan, I'm feeling at a loss for optimization ideas.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm feeling at a loss for optimization ideas.

Use one warp per iteration of nc with each thread calculating a partial sum, then combine the partial sums via warp_reduce_sum and have the first thread in the warp write back the result.

Copy link
Contributor Author

@piDack piDack Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha ha, I once also thought about using the wrap level api to calculate the sum. However, after taking a closer look, I realized that these additions are for the sum of single thread registers, not sum between thread in block. Therefore, wrap_reduce_sum might not be applicable here.Thx u review.if you have any other suggestions or better ideas, please feel free to share them. Your input is greatly appreciated.

}
x[i1] = sumf;
}
}

static void ssm_conv_f32_cuda(
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {

const dim3 block_dims(WARP_SIZE, n_s, 1);
const int nblocks = n_t;

ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
src0, src1,
src0_nb0, src0_nb1, src0_nb2,
src1_nb1,
dst,
dst_nb0, dst_nb1, dst_nb2,
nc, ncs, nr);
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch

GGML_ASSERT( dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

ssm_conv_f32_cuda(src0_d, src1_d,
src0->nb[0], src0->nb[1], src0->nb[2],
src1->nb[1],
dst_d,
dst->nb[0], dst->nb[1], dst->nb[2],
nc, ncs, nr, n_t, n_s,
stream);
}

3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/ssm_conv.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
144 changes: 144 additions & 0 deletions ggml/src/ggml-cuda/ssm_scan.cu
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comments for ssm_conv.cu largely apply here as well.

Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#include "ssm_scan.cuh"

template <int block_size>
static __global__ void ssm_scan_f32(
const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2,
const int src3_nb1,
const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2,
float * dst,
const int nc, const int nr) {

const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;

const int ith = tid;
const int nth = WARP_SIZE;

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0;

const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s}

// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }

// d_inner
#pragma unroll
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
}
y[i1] = sumf;
}
}

static void ssm_scan_f32_cuda(
const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2,
const int src3_nb1,
const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2,
float * dst,
const int nc, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {

const dim3 block_dims(WARP_SIZE, n_s, 1);
const int nblocks = n_t;

ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
src0, src1, src2, src3,
src4, src5,
src0_nb1, src0_nb2,
src1_nb0, src1_nb1, src1_nb2, src1_nb3,
src2_nb0, src2_nb1, src2_nb2,
src3_nb1,
src4_nb1, src4_nb2,
src5_nb1, src5_nb2,
dst,
nc, nr);
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C

const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch

GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const float * src2_d = (const float *)src2->data;
const float * src3_d = (const float *)src3->data;
const float * src4_d = (const float *)src4->data;
const float * src5_d = (const float *)src5->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

ssm_scan_f32_cuda(
src0_d, src1_d, src2_d, src3_d,
src4_d, src5_d,
src0->nb[1], src0->nb[2],
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
src2->nb[0], src2->nb[1], src2->nb[2],
src3->nb[1],
src4->nb[1], src4->nb[2],
src5->nb[1], src5->nb[2],
dst_d,
nc, nr, n_t, n_s,
stream);
}
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/ssm_scan.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6 changes: 3 additions & 3 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9119,9 +9119,9 @@ static struct ggml_tensor * llm_build_mamba(

// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_dt_b_c_rms) {
dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
B = ggml_rms_norm(ctx, B, norm_rms_eps);
C = ggml_rms_norm(ctx, C, norm_rms_eps);
dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
}

// {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
Expand Down
64 changes: 62 additions & 2 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,8 @@ struct test_case {

if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
printf("sentinel mismatch: %s ", t1->name);
ud->ok = false;
return true;
// ud->ok = false;
// return true;
}
}

Expand Down Expand Up @@ -1642,6 +1642,64 @@ struct test_leaky_relu : public test_case {
}
};

// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
const int64_t d_conv;
const int64_t d_inner;
const int64_t n_seq_tokens;
const int64_t n_seqs;

std::string vars() override {
return VARS_TO_STR5(type, d_conv, d_inner, n_seq_tokens, n_seqs);
}

test_ssm_conv(ggml_type type = GGML_TYPE_F32,
int64_t d_conv = 4,
int64_t d_inner = 1536,
int64_t n_seq_tokens = 7,
int64_t n_seqs = 2)
: type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * sx = ggml_new_tensor_3d(ctx, type, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
ggml_tensor * out = ggml_ssm_conv(ctx, sx, c);
return out;
}
};

// GGML_OP_SSM_SCAN
struct test_ssm_scan : public test_case {
const ggml_type type;
const int64_t d_state;
const int64_t d_inner;
const int64_t n_seq_tokens;
const int64_t n_seqs;

std::string vars() override {
return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
}

test_ssm_scan(ggml_type type = GGML_TYPE_F32,
int64_t d_state = 16,
int64_t d_inner = 1536,
int64_t n_seq_tokens = 7,
int64_t n_seqs = 2)
: type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_state, d_inner, n_seqs);
ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, d_state, d_inner);
ggml_tensor * B = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
ggml_tensor * C = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
return out;
}
};

// GGML_OP_FLASH_ATTN_EXT
struct test_flash_attn_ext : public test_case {
const int64_t hs; // head size
Expand Down Expand Up @@ -2433,6 +2491,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
test_cases.emplace_back(new test_ssm_conv());
test_cases.emplace_back(new test_ssm_scan());

for (int hs : { 64, 80, 128, 256, }) {
for (bool mask : { true, false } ) {
Expand Down