Skip to content

Commit 7380414

Browse files
authored
ggml : fix SSM_SCAN for n_groups > 1 (ggml-org#15625)
1 parent c8d0d14 commit 7380414

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
90039003
GGML_ASSERT(src4->nb[0] == sizeof(float));
90049004
GGML_ASSERT(src5->nb[0] == sizeof(float));
90059005
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
9006-
// allows optimizing the modulo since n_group should be a power of 2
9007-
GGML_ASSERT((ng & -ng) == ng);
9006+
GGML_ASSERT(nh % ng == 0);
90089007

90099008
// heads per thread
90109009
const int dh = (nh + nth - 1)/nth;
@@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
90359034
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
90369035
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
90379036
const float dA = expf(dt_soft_plus * A[h]);
9037+
const int g = h / (nh / ng); // repeat_interleave
90389038

90399039
// dim
90409040
for (int i1 = 0; i1 < nr; ++i1) {
@@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
90579057
// TODO: maybe unroll more?
90589058
for (int j = 0; j < 1; j++) {
90599059
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
9060-
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
9061-
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
9060+
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
9061+
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
90629062

90639063
t0 = GGML_F32_VEC_MUL(t0, adA);
90649064
t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -9090,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
90909090
for (int i = 0; i < np; i += GGML_F32_STEP) {
90919091
for (int j = 0; j < GGML_F32_ARR; j++) {
90929092
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
9093-
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
9094-
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
9093+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
9094+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
90959095

90969096
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
90979097
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -9113,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
91139113
// d_state
91149114
for (int i0 = np; i0 < nc; ++i0) {
91159115
const int i = i0 + ii*nc;
9116-
const int ig = i0 + (h & (ng - 1))*nc;
9116+
const int ig = i0 + g*nc;
91179117
// state = prev_state * dA + dB * x
91189118
const float state = (s0[i] * dA) + (B[ig] * x_dt);
91199119
// y = rowwise_dotprod(state, C)
@@ -9130,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
91309130
for (int h = ih0; h < ih1; ++h) {
91319131
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
91329132
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
9133+
const int g = h / (nh / ng); // repeat_interleave
91339134

91349135
// dim
91359136
for (int i1 = 0; i1 < nr; ++i1) {
@@ -9144,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
91449145
// TODO: what happens when (d_state % svcntw()) != 0?
91459146
for (int64_t k = 0; k < nc; k += svcntw()) {
91469147
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
9147-
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
9148-
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
9148+
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
9149+
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
91499150
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
91509151

91519152
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9165,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
91659166
// d_state
91669167
for (int i0 = 0; i0 < nc; ++i0) {
91679168
const int i = i0 + ii*nc;
9168-
const int ig = i0 + (h & (ng - 1))*nc;
9169+
const int ig = i0 + g*nc;
91699170
// state = prev_state * dA + dB * x
91709171
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
91719172
// y = rowwise_dotprod(state, C)

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
129129
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
130130
const int seq_idx = blockIdx.y;
131131

132-
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
132+
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
133133

134134
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
135135
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1983,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
19831983
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
19841984
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
19851985
const int64_t i = i0 + i1*nc;
1986+
const int64_t g = ir / (nh / ng); // repeat_interleave
19861987
float s0 = s0_buff[i];
19871988
float s = s_buff[i];
19881989

19891990
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
19901991
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
19911992
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
1992-
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
1993-
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
1993+
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
1994+
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
19941995
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
19951996

19961997
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -2098,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
20982099
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
20992100
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
21002101
const int64_t i = i0 + i1*nc;
2102+
const int64_t g = ir / (nh / ng); // repeat_interleave
21012103
float s0 = s0_buff[i];
21022104
float s = s_buff[i];
21032105

21042106
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
21052107
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
21062108
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2107-
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
2108-
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
2109+
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2110+
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
21092111
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
21102112

21112113
for (int64_t i2 = 0; i2 < n_t; ++i2) {

0 commit comments

Comments
 (0)