@@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
9003
9003
GGML_ASSERT (src4->nb [0 ] == sizeof (float ));
9004
9004
GGML_ASSERT (src5->nb [0 ] == sizeof (float ));
9005
9005
GGML_ASSERT (src6->nb [0 ] == sizeof (int32_t ));
9006
- // allows optimizing the modulo since n_group should be a power of 2
9007
- GGML_ASSERT ((ng & -ng) == ng);
9006
+ GGML_ASSERT (nh % ng == 0 );
9008
9007
9009
9008
// heads per thread
9010
9009
const int dh = (nh + nth - 1 )/nth;
@@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
9035
9034
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
9036
9035
const float dt_soft_plus = dt[h] <= 20 .0f ? log1pf (expf (dt[h])) : dt[h];
9037
9036
const float dA = expf (dt_soft_plus * A[h]);
9037
+ const int g = h / (nh / ng); // repeat_interleave
9038
9038
9039
9039
// dim
9040
9040
for (int i1 = 0 ; i1 < nr; ++i1) {
@@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
9057
9057
// TODO: maybe unroll more?
9058
9058
for (int j = 0 ; j < 1 ; j++) {
9059
9059
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD (s0 + i + j*ggml_f32_epr + ii*nc);
9060
- GGML_F32_VEC t1 = GGML_F32_VEC_LOAD (B + i + j*ggml_f32_epr + (h & (ng - 1 )) *nc);
9061
- GGML_F32_VEC t2 = GGML_F32_VEC_LOAD (C + i + j*ggml_f32_epr + (h & (ng - 1 )) *nc);
9060
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD (B + i + j*ggml_f32_epr + g *nc);
9061
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD (C + i + j*ggml_f32_epr + g *nc);
9062
9062
9063
9063
t0 = GGML_F32_VEC_MUL (t0, adA);
9064
9064
t1 = GGML_F32_VEC_MUL (t1, axdt);
@@ -9090,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
9090
9090
for (int i = 0 ; i < np; i += GGML_F32_STEP) {
9091
9091
for (int j = 0 ; j < GGML_F32_ARR; j++) {
9092
9092
ax[j] = GGML_F32_VEC_LOAD (s0 + i + j*GGML_F32_EPR + ii*nc);
9093
- ay[j] = GGML_F32_VEC_LOAD (B + i + j*GGML_F32_EPR + (h & (ng - 1 )) *nc);
9094
- az[j] = GGML_F32_VEC_LOAD (C + i + j*GGML_F32_EPR + (h & (ng - 1 )) *nc);
9093
+ ay[j] = GGML_F32_VEC_LOAD (B + i + j*GGML_F32_EPR + g *nc);
9094
+ az[j] = GGML_F32_VEC_LOAD (C + i + j*GGML_F32_EPR + g *nc);
9095
9095
9096
9096
ax[j] = GGML_F32_VEC_MUL (ax[j], adA);
9097
9097
ay[j] = GGML_F32_VEC_MUL (ay[j], axdt);
@@ -9113,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
9113
9113
// d_state
9114
9114
for (int i0 = np; i0 < nc; ++i0) {
9115
9115
const int i = i0 + ii*nc;
9116
- const int ig = i0 + (h & (ng - 1 )) *nc;
9116
+ const int ig = i0 + g *nc;
9117
9117
// state = prev_state * dA + dB * x
9118
9118
const float state = (s0[i] * dA) + (B[ig] * x_dt);
9119
9119
// y = rowwise_dotprod(state, C)
@@ -9130,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
9130
9130
for (int h = ih0; h < ih1; ++h) {
9131
9131
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
9132
9132
const float dt_soft_plus = dt[h] <= 20 .0f ? log1pf (expf (dt[h])) : dt[h];
9133
+ const int g = h / (nh / ng); // repeat_interleave
9133
9134
9134
9135
// dim
9135
9136
for (int i1 = 0 ; i1 < nr; ++i1) {
@@ -9144,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
9144
9145
// TODO: what happens when (d_state % svcntw()) != 0?
9145
9146
for (int64_t k = 0 ; k < nc; k += svcntw ()) {
9146
9147
svfloat32_t vA = GGML_F32_VEC_LOAD (&A[h*nc + k]);
9147
- svfloat32_t vB = GGML_F32_VEC_LOAD (&B[k + (h & (ng - 1 )) *nc]);
9148
- svfloat32_t vC = GGML_F32_VEC_LOAD (&C[k + (h & (ng - 1 )) *nc]);
9148
+ svfloat32_t vB = GGML_F32_VEC_LOAD (&B[k + g *nc]);
9149
+ svfloat32_t vC = GGML_F32_VEC_LOAD (&C[k + g *nc]);
9149
9150
svfloat32_t vs0 = GGML_F32_VEC_LOAD (&s0[ii*nc + k]);
9150
9151
9151
9152
svfloat32_t t1 = GGML_F32_VEC_MUL (vdt_soft_plus, vA);
@@ -9165,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
9165
9166
// d_state
9166
9167
for (int i0 = 0 ; i0 < nc; ++i0) {
9167
9168
const int i = i0 + ii*nc;
9168
- const int ig = i0 + (h & (ng - 1 )) *nc;
9169
+ const int ig = i0 + g *nc;
9169
9170
// state = prev_state * dA + dB * x
9170
9171
const float state = (s0[i] * expf (dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
9171
9172
// y = rowwise_dotprod(state, C)
0 commit comments