
Commit d79f3cd

remove manual conv class in mamba1 (ml-explore#436)
* remove manual conv class
* remove slice
* add sanitize
* format
1 parent 103877e commit d79f3cd


mlx_lm/models/mamba.py

Lines changed: 18 additions & 40 deletions
@@ -50,32 +50,6 @@ def __post_init__(self):
         self.use_bcdt_rms = True
 
 
-class DepthWiseConv1d(nn.Module):
-    def __init__(self, channels, kernel_size, bias=True, padding=0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.padding = padding
-        self.weight = mx.random.normal((self.channels, kernel_size, 1))
-        self.bias = mx.zeros((channels,)) if bias else None
-
-    def __call__(self, x, cache=None):
-        B, L, C = x.shape
-        groups, K, _ = self.weight.shape
-
-        if cache is not None:
-            x = mx.concatenate([cache, x], axis=1)
-        else:
-            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
-
-        y = mx.conv_general(x, self.weight, groups=groups)
-
-        if self.bias is not None:
-            y = y + self.bias
-
-        return y, x[:, -K + 1 :, :]
-
-
 class MambaBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
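The removed DepthWiseConv1d was a thin wrapper around mx.conv_general with groups equal to the channel count, which is exactly what nn.Conv1d computes when groups == in_channels == out_channels. A minimal sketch of that equivalence (shapes and variable names are illustrative, not part of the commit):

import mlx.core as mx
import mlx.nn as nn

C, K = 8, 4  # channels and kernel size, illustrative values
x = mx.random.normal((1, 16, C))  # MLX convolutions are channels-last: (B, L, C)

# Depthwise conv as a module: one group per channel.
conv = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=K, groups=C, bias=False)

# The same conv via the low-level op the deleted class used.
y_manual = mx.conv_general(x, conv.weight, groups=C)
y_module = conv(x)
assert mx.allclose(y_manual, y_module).item()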
@@ -97,11 +71,13 @@ def __init__(self, args: ModelArgs):
             self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
         )
 
-        self.conv1d = DepthWiseConv1d(
-            channels=self.intermediate_size,
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
             kernel_size=self.conv_kernel_size,
+            groups=self.intermediate_size,
             bias=self.use_conv_bias,
-            padding=self.conv_kernel_size - 1,
+            padding=0,
         )
 
         self.x_proj = nn.Linear(
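padding drops to 0 because nn.Conv1d pads both ends of the sequence, while Mamba's convolution must be causal: only left context is allowed, so the K - 1 left pad is now applied explicitly in _process_sequence, as the next hunk shows. A toy comparison, with made-up sizes:

import mlx.core as mx
import mlx.nn as nn

K = 4
x = mx.random.normal((1, 10, 3))

# Symmetric padding: 13 outputs for 10 inputs, and outputs can see future steps.
sym = nn.Conv1d(3, 3, K, groups=3, padding=K - 1)
y_sym = sym(x)  # (1, 13, 3)

# Causal alternative: left-pad K - 1 zeros, then convolve with padding=0.
causal = nn.Conv1d(3, 3, K, groups=3, padding=0)
y = causal(mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]))  # (1, 10, 3), no lookahead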
@@ -148,13 +124,15 @@ def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
         xz = self.in_proj(x)
         x, z = xz.split(indices_or_sections=2, axis=-1)
-
-        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        K = self.conv_kernel_size
+        if conv_cache is not None:
+            x_full = mx.concatenate([conv_cache, x], axis=1)
+        else:
+            x_full = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+        conv_out = self.conv1d(x_full)
+        new_conv_cache = x_full[:, -(K - 1) :, :]
         x = nn.silu(conv_out)
-
         A = -mx.exp(self.A_log)
-
-        outputs = []
         current_state = state_cache
         y = []
         for t in range(T):
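The inlined cache logic saves the last K - 1 raw inputs so the next chunk convolves with the correct left context, which should make chunked decoding match a single full-sequence pass exactly. A self-contained sketch of that invariant (the chunk split and sizes are arbitrary):

import mlx.core as mx
import mlx.nn as nn

K, C = 4, 3
conv = nn.Conv1d(C, C, K, groups=C, padding=0)
x = mx.random.normal((1, 12, C))

# Reference: left-pad once, convolve the whole sequence.
full = conv(mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]))

# Chunked: carry the last K - 1 inputs between calls.
cache = mx.zeros((1, K - 1, C))  # plays the role of the initial zero padding
outs = []
for chunk in (x[:, :5], x[:, 5:]):
    x_full = mx.concatenate([cache, chunk], axis=1)
    outs.append(conv(x_full))
    cache = x_full[:, -(K - 1) :, :]
assert mx.allclose(full, mx.concatenate(outs, axis=1)).item()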
@@ -228,15 +206,15 @@ def __call__(self, inputs: mx.array, cache=None):
 
         return logits
 
-    def sanitize(self, weights):
-        for k, v in weights.items():
-            if "conv1d.weight" in k and v.shape[-1] != 1:
-                weights[k] = v.moveaxis(2, 1)
-        return weights
-
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]
 
     @property
     def layers(self):
         return self.backbone.layers
+
+    def sanitize(self, weights):
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
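sanitize itself is unchanged, only moved below the layers property; it still remaps checkpoint weights from PyTorch's depthwise layout (channels, 1, kernel_size) to MLX's Conv1d layout (out_channels, kernel_size, in_channels / groups), and the shape[-1] != 1 guard makes it a no-op on already-converted weights. Roughly:

import mlx.core as mx

C, K = 8, 4
w_torch = mx.random.normal((C, 1, K))  # PyTorch depthwise Conv1d: (out, in/groups, K)
w_mlx = w_torch.moveaxis(2, 1)         # MLX Conv1d: (out, K, in/groups)
assert w_mlx.shape == (C, K, 1)        # shape[-1] == 1, so a second sanitize pass skips it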
