@@ -96,7 +96,10 @@ def __call__(
         k = k.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

-        if cache is not None:
+        if cache is None:
+            cache = (None, None)
+
+        if cache[0] is not None:
             offset = cache[1].offset
             last_k, last_v = cache[0][0], cache[0][1]
         else:
@@ -110,7 +113,7 @@ def __call__(
         q = self.rope(q, offset=offset)
         k = self.rope(k, offset=offset)

-        if cache is not None:
+        if cache[0] is not None:
             k, v = cache[1].update_and_fetch(k, v)
             if L > 0:
                 cache[0][0] = k_init[:, :, -1:, :]
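Note (not in the original diff): the two hunks above appear to assume that each per-layer cache entry is a pair (last_kv, kv_cache), where last_kv holds the trailing key/value slice and kv_cache exposes .offset and update_and_fetch. Below is a minimal runnable sketch of why normalizing a missing cache to (None, None) is convenient; the names are illustrative stand-ins, not part of the diff.

# Sketch, assuming the pair convention described above.
from types import SimpleNamespace

dummy_kv_cache = SimpleNamespace(offset=4)   # stand-in for a real KV cache object
last_kv = None                               # no trailing k/v slice on the first call

for cache in (None, ([last_kv, last_kv], dummy_kv_cache)):
    if cache is None:
        cache = (None, None)                 # cache[0] / cache[1] can now always be indexed
    offset = cache[1].offset if cache[0] is not None else 0
    print(offset)                            # prints 0, then 4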
@@ -167,17 +170,40 @@ def __init__(self, config: ModelArgs):
         self.layers = [DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def __call__(
-        self, inputs: mx.array, mask: mx.array = None, cache: Any = None
-    ) -> mx.array:
+        self.sliding_window = config.sliding_window
+        self.first_swa_idx = None
+        if config.sliding_window_layers:
+            self.first_swa_idx = config.sliding_window_layers[0]
+
+        self.first_global_idx = None
+        self.swa_layers = set(config.sliding_window_layers)
+        for i in range(config.num_hidden_layers):
+            if i in self.swa_layers:
+                continue
+            self.first_global_idx = i
+            break
+
+    def __call__(self, inputs: mx.array, cache: Any = None) -> mx.array:
         x = self.embed_tokens(inputs)
-        if mask is None:
-            if cache is not None:
-                c = [cache[0][1]]
-            mask = create_attention_mask(x, c)
+
         if cache is None:
-            cache = [None] * len(self.layers)
-        for layer, c in zip(self.layers, cache):
+            cache = [(None, None)] * len(self.layers)
+
+        if self.first_global_idx is None:
+            c_global = None
+        else:
+            c_global = cache[self.first_global_idx][1]
+
+        if self.first_swa_idx is None:
+            c_swa = None
+        else:
+            c_swa = cache[self.first_swa_idx][1]
+
+        global_mask = create_attention_mask(x, c_global)
+        swa_mask = create_attention_mask(x, c_swa, window_size=self.sliding_window)
+
+        for l, (layer, c) in enumerate(zip(self.layers, cache)):
+            mask = swa_mask if l in self.swa_layers else global_mask
             x = layer(x, mask, c)
         return self.norm(x)

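Note (not in the original diff): a self-contained sketch, with hypothetical config values, of how the new Model.__call__ resolves a mask per layer. Layers listed in sliding_window_layers get the windowed mask, every other layer gets the global causal mask, and first_swa_idx / first_global_idx simply locate one cache of each kind so the two masks are built only once.

# Sketch with made-up values; the real ones come from ModelArgs.
num_hidden_layers = 6
sliding_window_layers = [1, 3, 5]                   # hypothetical layer indices

swa_layers = set(sliding_window_layers)
first_swa_idx = sliding_window_layers[0] if sliding_window_layers else None
first_global_idx = next(
    (i for i in range(num_hidden_layers) if i not in swa_layers), None
)
print(first_swa_idx, first_global_idx)              # -> 1 0

global_mask, swa_mask = "global_mask", "swa_mask"   # stand-ins for the real mask arrays
for l in range(num_hidden_layers):
    mask = swa_mask if l in swa_layers else global_mask
    print(l, mask)                                  # layers 1, 3, 5 -> swa_mask; others -> global_mask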
@@ -215,10 +241,8 @@ def sanitize(self, weights: dict) -> dict:
             weights["lm_head.weight"] = w
         return weights

-    def __call__(
-        self, inputs: mx.array, mask: mx.array = None, cache: Any = None
-    ) -> mx.array:
-        outputs = self.model(inputs, mask, cache)
+    def __call__(self, inputs: mx.array, cache: Any = None) -> mx.array:
+        outputs = self.model(inputs, cache)
         return self.lm_head(outputs)

     @property