@@ -50,32 +50,6 @@ def __post_init__(self):
         self.use_bcdt_rms = True


-class DepthWiseConv1d(nn.Module):
-    def __init__(self, channels, kernel_size, bias=True, padding=0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.padding = padding
-        self.weight = mx.random.normal((self.channels, kernel_size, 1))
-        self.bias = mx.zeros((channels,)) if bias else None
-
-    def __call__(self, x, cache=None):
-        B, L, C = x.shape
-        groups, K, _ = self.weight.shape
-
-        if cache is not None:
-            x = mx.concatenate([cache, x], axis=1)
-        else:
-            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
-
-        y = mx.conv_general(x, self.weight, groups=groups)
-
-        if self.bias is not None:
-            y = y + self.bias
-
-        return y, x[:, -K + 1 :, :]
-
-
 class MambaBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
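Note (reviewer sketch, not part of the diff): the class removed above was a hand-rolled causal depthwise convolution, with one `(kernel_size, 1)` filter per channel applied via a grouped `mx.conv_general`. A grouped `nn.Conv1d` with `groups == channels` computes the same thing, which is what makes the custom layer redundant. A minimal equivalence check, assuming an MLX build where both `nn.Conv1d` and `mx.conv_general` accept `groups`:

```python
import mlx.core as mx
import mlx.nn as nn

# MLX is channels-last: inputs are (B, L, C), depthwise weights are (C, K, 1).
C, K = 8, 4
w = mx.random.normal((C, K, 1))          # same weight layout the removed class used
conv = nn.Conv1d(C, C, kernel_size=K, groups=C, bias=False)
conv.weight = w                          # reuse the same filters

x = mx.random.normal((1, 16, C))
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])   # causal left-padding

y_builtin = conv(x)                           # grouped nn.Conv1d
y_manual = mx.conv_general(x, w, groups=C)    # what DepthWiseConv1d did
# y_builtin and y_manual should match elementwise
```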
@@ -97,11 +71,13 @@ def __init__(self, args: ModelArgs):
             self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
         )

-        self.conv1d = DepthWiseConv1d(
-            channels=self.intermediate_size,
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
             kernel_size=self.conv_kernel_size,
+            groups=self.intermediate_size,
             bias=self.use_conv_bias,
-            padding=self.conv_kernel_size - 1,
+            padding=0,
         )

         self.x_proj = nn.Linear(
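Note: `padding=0` is deliberate. The layer no longer pads internally, because the caller now either left-pads the first chunk or prepends the conv cache (see the next hunk); internal padding would double-pad the cached path. A shape sketch with hypothetical sizes (e.g. `intermediate_size=1536`, `conv_kernel_size=4`):

```python
import mlx.core as mx
import mlx.nn as nn

D, K = 1536, 4                            # hypothetical sizes for illustration
conv = nn.Conv1d(
    in_channels=D, out_channels=D, kernel_size=K, groups=D, bias=True, padding=0
)
assert conv.weight.shape == (D, K, 1)     # one (K, 1) filter per channel

x = mx.random.normal((2, 10 + K - 1, D))  # caller supplies K - 1 frames of history
assert conv(x).shape == (2, 10, D)        # "valid" conv: output length == new tokens
```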
@@ -148,13 +124,15 @@ def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
         xz = self.in_proj(x)
         x, z = xz.split(indices_or_sections=2, axis=-1)
-
-        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        K = self.conv_kernel_size
+        if conv_cache is not None:
+            x_full = mx.concatenate([conv_cache, x], axis=1)
+        else:
+            x_full = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+        conv_out = self.conv1d(x_full)
+        new_conv_cache = x_full[:, -(K - 1) :, :]
         x = nn.silu(conv_out)
-
         A = -mx.exp(self.A_log)
-
-        outputs = []
         current_state = state_cache
         y = []
         for t in range(T):
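Note: the rolling cache keeps the last `K - 1` input frames so that generating in chunks matches a single full-sequence pass. A standalone sketch of that invariant (hypothetical helper mirroring the hunk's logic, under the same grouped-`nn.Conv1d` assumption as above):

```python
import mlx.core as mx
import mlx.nn as nn

def causal_conv_step(conv, x, conv_cache, K):
    # First chunk: zero left-padding; later chunks: prepend the cached frames.
    if conv_cache is not None:
        x_full = mx.concatenate([conv_cache, x], axis=1)
    else:
        x_full = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
    # Output length equals x.shape[1]; keep the last K - 1 frames as the new cache.
    return conv(x_full), x_full[:, -(K - 1) :, :]

D, K = 16, 4                                    # small hypothetical sizes
conv = nn.Conv1d(D, D, kernel_size=K, groups=D, padding=0)
x = mx.random.normal((1, 12, D))

y_full, _ = causal_conv_step(conv, x, None, K)  # one pass over the whole sequence

y1, cache = causal_conv_step(conv, x[:, :5, :], None, K)
y2, _ = causal_conv_step(conv, x[:, 5:, :], cache, K)
y_chunked = mx.concatenate([y1, y2], axis=1)    # should match y_full elementwise
```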
@@ -228,15 +206,15 @@ def __call__(self, inputs: mx.array, cache=None):

         return logits

-    def sanitize(self, weights):
-        for k, v in weights.items():
-            if "conv1d.weight" in k and v.shape[-1] != 1:
-                weights[k] = v.moveaxis(2, 1)
-        return weights
-
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]

     @property
     def layers(self):
         return self.backbone.layers
+
+    def sanitize(self, weights):
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
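Note: `sanitize` is moved, not changed, and it still matters with the built-in layer. PyTorch depthwise `Conv1d` checkpoints store the weight as `(channels, 1, kernel_size)`, while MLX's `Conv1d` expects `(channels, kernel_size, 1)`; `moveaxis(2, 1)` swaps the trailing axes, and the `shape[-1] != 1` guard keeps the conversion idempotent for weights already in MLX layout. A sketch with a hypothetical checkpoint shape:

```python
import mlx.core as mx

w_torch = mx.zeros((1536, 1, 4))   # (C, 1, K): PyTorch depthwise layout (hypothetical C, K)
w_mlx = w_torch.moveaxis(2, 1)     # (C, K, 1): the layout MLX's Conv1d expects
assert w_mlx.shape == (1536, 4, 1)

# Weights already in MLX layout have shape[-1] == 1 and are left untouched.
```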