
Commit 3bc7103

ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models to avoid some reshapes.

Not sure if it's a good idea, but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
1 parent 7d16e1b commit 3bc7103
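
To make the change concrete: the scan op used to add the per-head skip term x*D to its output internally; after this commit the op returns only the scan accumulation, and the caller applies x*D as a separate broadcasted multiply-add in the compute graph. A minimal numpy sketch of the equivalence (illustrative shapes and values, not the ggml code):

    import numpy as np

    rng = np.random.default_rng(0)
    n_head, head_dim = 4, 8
    x    = rng.standard_normal((n_head, head_dim))
    scan = rng.standard_normal((n_head, head_dim))  # stand-in for the SSM scan accumulation
    D    = rng.standard_normal((n_head, 1))         # one skip weight per head, unsqueezed

    y_before = scan + x * D  # what GGML_OP_SSM_SCAN used to produce
    y_after  = scan          # what the op produces now
    assert np.allclose(y_after + x * D, y_before)   # x*D is applied outside the op

Keeping D out of the op means the scan kernels no longer need D's buffer, which is what simplifies the weight buft support checks in src/llama.cpp.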

File tree

7 files changed: +100 -97 lines changed

convert_hf_to_gguf.py

Lines changed: 25 additions & 1 deletion
@@ -264,6 +264,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
+    # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that)
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        del new_name, bid  # unused
+
+        return data_torch.squeeze()
+
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
@@ -295,7 +301,7 @@ def prepare_tensors(self):
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                data = self.reshape_tensors(data_torch, new_name, bid).numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
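
The hook introduced above is a template-method pattern: the base class routes every tensor through reshape_tensors(), whose default reproduces the old squeeze(), and shape-sensitive models override it. A runnable sketch of the pattern with simplified, hypothetical signatures (the real classes take new_name and bid and match gguf tensor types):

    import numpy as np

    class Base:
        def reshape_tensors(self, t, new_name):
            return t.squeeze()                          # default keeps the old behavior

        def prepare(self, t, new_name):
            return self.reshape_tensors(t, new_name)    # every tensor passes through the hook

    class Mamba2(Base):
        def reshape_tensors(self, t, new_name):
            if new_name in ("ssm_a", "ssm_d"):          # hypothetical name matching
                return t.reshape((*t.shape, 1))
            return super().reshape_tensors(t, new_name)

    assert Mamba2().prepare(np.zeros(8), "ssm_a").shape == (8, 1)
    assert Mamba2().prepare(np.zeros((1, 3)), "w").shape == (3,)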
@@ -3063,6 +3069,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         yield (new_name, data_torch)
 
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            return data_torch.reshape((*data_torch.shape, 1))
+
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            return data_torch.reshape((n_group, d_inner // n_group))
+
+        return data_torch.squeeze()
+
+
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
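
A short numpy sketch of the new reshape semantics for Mamba-2 (hyperparameter values made up for illustration): SSM_A and SSM_D gain a trailing unit dimension instead of being squeezed, mirroring Mamba-1's shape semantics (a tensor of shape (n_head, 1) shows up as ne = [1, n_head] on the ggml side, since GGUF stores dimensions in reverse order), and SSM_NORM is split into one row per group:

    import numpy as np

    n_head, d_inner, n_group = 8, 64, 4

    # SSM_A / SSM_D: unsqueeze instead of squeeze
    A = np.zeros(n_head)
    assert A.reshape((*A.shape, 1)).shape == (n_head, 1)

    # SSM_NORM: regroup the flat (d_inner,) weight per SSM group
    norm = np.zeros(d_inner)
    assert norm.reshape((n_group, d_inner // n_group)).shape == (4, 16)

    # all other tensors: unchanged squeeze() default
    assert np.zeros((1, 3, 1)).squeeze().shape == (3,)

As the commit message notes, this changes the stored shapes, so Mamba-2 GGUF files converted before this commit need to be re-converted.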

ggml/include/ggml.h

Lines changed: 0 additions & 1 deletion
@@ -1828,7 +1828,6 @@ extern "C" {
            struct ggml_tensor * A,
            struct ggml_tensor * B,
            struct ggml_tensor * C,
-            struct ggml_tensor * D,
            struct ggml_tensor * ids);
 
    // partition into non-overlapping windows with padding if needed

ggml/src/ggml-metal.m

Lines changed: 24 additions & 33 deletions
@@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node(
                struct ggml_tensor * src4 = node->src[4];
                struct ggml_tensor * src5 = node->src[5];
                struct ggml_tensor * src6 = node->src[6];
-                struct ggml_tensor * src7 = node->src[7];
 
                GGML_ASSERT(src3);
                GGML_ASSERT(src4);
                GGML_ASSERT(src5);
                GGML_ASSERT(src6);
-                GGML_ASSERT(src7);
 
                size_t offs_src3 = 0;
                size_t offs_src4 = 0;
                size_t offs_src5 = 0;
                size_t offs_src6 = 0;
-                size_t offs_src7 = 0;
 
                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
                id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
                id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
-                id<MTLBuffer> id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil;
 
                const int64_t ne30 = src3->ne[0];
                const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
@@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node(
 
                const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
-                const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);
-
-                const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70);
-
                const int64_t d_state = ne00;
                const int64_t d_inner = ne01;
                const int64_t n_head = ne02;
@@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node(
                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
-                [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:8];
-
-                [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9];
-                [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10];
-                [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
-                [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
-                [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14];
-
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-                [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+
+                [encoder setBytes:&d_state length:sizeof(d_state) atIndex:8];
+                [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:9];
+                [encoder setBytes:&n_head length:sizeof(n_head) atIndex:10];
+                [encoder setBytes:&n_group length:sizeof(n_group) atIndex:11];
+                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12];
+                [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:13];
+
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20];
+                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21];
+                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23];
+                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24];
+                [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25];
+                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26];
+                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27];
+                [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28];
                // NOTE: max index is 31
 
                if (ne30 == 1) {

ggml/src/ggml-metal.metal

Lines changed: 4 additions & 10 deletions
@@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32(
        device const void * src4,
        device const void * src5,
        device const void * src6,
-        device const void * src7,
        device float * dst,
        constant int64_t & d_state,
        constant int64_t & d_inner,
@@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32(
    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);
-    const uint64_t nb60 = sizeof(float);
 
    const int64_t nc = d_state;
    const int64_t nr = d_inner;
@@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32(
 
    const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
 
-    device const int32_t * ids = (device const int32_t *) src7;
+    device const int32_t * ids = (device const int32_t *) src6;
 
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
    device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
@@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32(
    device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh}
    device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
    device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
-    device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
    device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
 
    const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
@@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32(
        s[i] = state;
    }
 
-    y[0] = sumf + x[0] * D[0];
+    y[0] = sumf;
 
    // recurse
    s0 = s;
@@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group(
        device const void * src4,
        device const void * src5,
        device const void * src6,
-        device const void * src7,
        device float * dst,
        constant int64_t & d_state,
        constant int64_t & d_inner,
@@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group(
    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);
-    const uint64_t nb60 = sizeof(float);
 
    const int64_t nc = d_state;
    const int64_t nr = d_inner;
@@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group(
 
    const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
 
-    device const int32_t * ids = (device const int32_t *) src7;
+    device const int32_t * ids = (device const int32_t *) src6;
 
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
    device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
@@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group(
    device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
    device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
    device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
-    device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
    device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
 
    const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
@@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group(
        s[i] = state;
    }
 
-    y[0] = sumf + x[0] * D[0];
+    y[0] = sumf;
 
    // recurse
    s0 = s;
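
For reference, the inner loop that both Metal kernels (and the CPU path below) implement can be sketched in numpy as follows; this is a single-channel illustration under assumed 1-D shapes, not the kernel code itself. The step size goes through a softplus with an overflow guard, the previous state decays by exp(dt*A), the input enters through B, and the output accumulates through C. The only behavioral change in this commit is that the final x*D skip term is no longer added here:

    import numpy as np

    def ssm_scan_ref(s0, A, B, C, x, dt):
        # s0, A, B, C: (d_state,) vectors; x, dt: scalars
        dt_sp = np.log1p(np.exp(dt)) if dt <= 20.0 else dt  # softplus, bypassed for large dt
        s = s0 * np.exp(dt_sp * A) + x * dt_sp * B          # state decay + input injection
        y = float(np.sum(s * C))                            # y[0] = sumf (no "+ x[0] * D[0]")
        return y, s

In kernel_ssm_scan_f32_group, A holds a single value per head ({1, nh}); that corresponds to passing a length-1 A here and letting it broadcast over d_state.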

ggml/src/ggml.c

Lines changed: 6 additions & 14 deletions
@@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv(
    const int64_t n_s = sx->ne[2];
 
    // TODO: maybe support other strides than 1?
-    // FIXME: this is always true?
    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
    GGML_ASSERT(sx->ne[1] == d_inner);
    GGML_ASSERT(n_t >= 0);
@@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan(
        struct ggml_tensor * A,
        struct ggml_tensor * B,
        struct ggml_tensor * C,
-        struct ggml_tensor * D,
        struct ggml_tensor * ids) {
    GGML_ASSERT(ggml_is_contiguous(s));
    GGML_ASSERT(ggml_is_contiguous(dt));
@@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan(
    GGML_ASSERT(B->ne[0] == d_state);
    GGML_ASSERT(B->ne[2] == n_seq_tokens);
    GGML_ASSERT(B->ne[3] == n_seqs);
-    GGML_ASSERT(D->ne[0] == n_head);
-    GGML_ASSERT(ggml_is_vector(D));
    GGML_ASSERT(ids->ne[0] == n_seqs);
    GGML_ASSERT(ggml_is_vector(ids));
    GGML_ASSERT(A->ne[1] == n_head);
@@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan(
    result->src[3] = A;
    result->src[4] = B;
    result->src[5] = C;
-    result->src[6] = D;
-    result->src[7] = ids;
+    result->src[6] = ids;
 
    return result;
}
@@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32(
    const struct ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
    const struct ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
    const struct ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
-    const struct ggml_tensor * src6 = dst->src[6]; // D   {n_head}
-    const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}
+    const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
    const int ith = params->ith;
    const int nth = params->nth;
@@ -16240,8 +16234,7 @@ static void ggml_compute_forward_ssm_scan_f32(
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
-    GGML_ASSERT(src6->nb[0] == sizeof(float));
-    GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
    // allows optimizing the modulo since n_group should be a power of 2
    GGML_ASSERT((ng & -ng) == ng);
 
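The assert above enables a standard bit trick: when ng is a power of two, ir % ng equals ir & (ng - 1), which is exactly how the kernels compute the group index for B and C. A quick self-contained check (Python for brevity; note that (ng & -ng) == ng would also accept 0, but n_group is at least 1 here):

    for ng in (1, 2, 4, 8):
        assert (ng & -ng) == ng                    # power-of-two test used by the assert
        for ir in range(64):
            assert (ir % ng) == (ir & (ng - 1))    # modulo replaced by a mask
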
@@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32(
    const int ih0 = dh*ith;
    const int ih1 = MIN(ih0 + dh, nh);
 
-    const int32_t * ids = (const int32_t *) src7->data;
+    const int32_t * ids = (const int32_t *) src6->data;
 
    for (int i3 = 0; i3 < ns; ++i3) {
        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
@@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32(
            const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
            const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
            const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
-            const float * D = (const float *) ((const char *) src6->data); // {nh}
            float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
 
            if (src3->ne[0] == 1) {
@@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                        sumf += state * C[ig];
                        s[i] = state;
                    }
-                    y[ii] = sumf + x[ii] * D[h];
+                    y[ii] = sumf;
                }
            }
        } else {
@@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                        sumf += state * C[ig];
                        s[i] = state;
                    }
-                    y[ii] = sumf + x[ii] * D[h];
+                    y[ii] = sumf;
                }
            }
        }
