Commit 103877e

some cleanup + tests towards batching (ml-explore#430)
1 parent 64574e1 commit 103877e


74 files changed (+939, -546 lines)

mlx_lm/evaluate.py
Lines changed: 1 addition & 3 deletions

@@ -107,9 +107,7 @@ def _score_fn(self, inputs, cache: Optional[Any] = None, step_size: int = 2048):
         T = inp.shape[1]

         offset = cache[0].offset
-        mask = create_causal_mask(T, offset, lengths=lengths)
-
-        logits = self._model(inp, cache=cache, mask=mask)
+        logits = self._model(inp, cache=cache)
         log_probs = nn.log_softmax(logits.astype(mx.float32))

         score = mx.take_along_axis(

mlx_lm/generate.py
Lines changed: 0 additions & 7 deletions

@@ -686,7 +686,6 @@ def stream_generate(
             prompt, model, draft_model, **kwargs
         )
     with wired_limit(model, [generation_stream]):
-        detokenizer.reset()
         tic = time.perf_counter()
         for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:

@@ -731,7 +730,6 @@ def generate(
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, List[int]],
     verbose: bool = False,
-    formatter: Optional[Callable] = None,
     **kwargs,
 ) -> str:
     """

@@ -746,11 +744,6 @@ def generate(
        kwargs: The remaining options get passed to :func:`stream_generate`.
          See :func:`stream_generate` for more details.
     """
-    if formatter is not None:
-        print(
-            "[Warning] Text formatting is deprecated and no longer used. "
-            "The argument will be removed in a future version."
-        )
     if verbose:
         print("=" * 10)
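With the deprecated formatter argument removed, generate() takes just the model, tokenizer, prompt, the verbose flag, and keyword options that are forwarded to stream_generate. A minimal usage sketch; the model id and the max_tokens value are illustrative, not taken from this commit:

```python
from mlx_lm import generate, load

# Hypothetical model repo; any supported mlx-community model id works the same way.
model, tokenizer = load("mlx-community/SmolLM-135M-Instruct-4bit")

# No `formatter=` keyword anymore; extra kwargs flow through to stream_generate.
text = generate(
    model, tokenizer, prompt="Write a haiku about GPUs.", max_tokens=64, verbose=True
)
print(text)
```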

mlx_lm/models/afm7.py
Lines changed: 3 additions & 6 deletions

@@ -350,18 +350,16 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: mx.array = None,
         cache=None,
     ):
         h = self.embedding(inputs)

-        if mask is None:
-            mask = create_attention_mask(h, cache)
-
         if cache is None:
             cache = [None] * len(self.layers)
             cache[-1] = ConcatenateKVCache()

+        mask = create_attention_mask(h, cache[0])
+
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, cache=c)

@@ -382,10 +380,9 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, mask, cache)
+        out = self.model(inputs, cache)
         out = self.model.embedding.as_linear(out)
         return out
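The same change, dropping the mask parameter and deriving the mask from cache[0] inside the model, is applied to the remaining model files below (apertus, bailing_moe, bitnet, and the other models in the full commit). From the caller's side a model is now invoked with only (inputs, cache); a hedged sketch of that convention, with an illustrative model id:

```python
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

# Illustrative model id; any supported model follows the same calling convention.
model, tokenizer = load("mlx-community/SmolLM-135M-Instruct-4bit")

prompt = mx.array(tokenizer.encode("The capital of France is"))[None]
cache = make_prompt_cache(model)

# No mask argument: the model builds its own mask via
# create_attention_mask(h, cache[0]) internally.
logits = model(prompt, cache=cache)
print(logits.shape)  # (1, prompt_length, vocab_size)
```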

mlx_lm/models/apertus.py
Lines changed: 3 additions & 6 deletions

@@ -177,17 +177,15 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
         h = self.embed_tokens(inputs)

-        if mask is None:
-            mask = create_attention_mask(h, cache)
-
         if cache is None:
             cache = [None] * len(self.layers)

+        mask = create_attention_mask(h, cache[0])
+
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask=mask, cache=c)

@@ -205,10 +203,9 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
-        out = self.model(inputs, mask, cache)
+        out = self.model(inputs, cache)
         return self.lm_head(out)

     def sanitize(self, weights):

mlx_lm/models/baichuan_m1.py
Lines changed: 39 additions & 15 deletions

@@ -96,7 +96,10 @@ def __call__(
         k = k.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

-        if cache is not None:
+        if cache is None:
+            cache = (None, None)
+
+        if cache[0] is not None:
             offset = cache[1].offset
             last_k, last_v = cache[0][0], cache[0][1]
         else:

@@ -110,7 +113,7 @@ def __call__(
         q = self.rope(q, offset=offset)
         k = self.rope(k, offset=offset)

-        if cache is not None:
+        if cache[0] is not None:
             k, v = cache[1].update_and_fetch(k, v)
             if L > 0:
                 cache[0][0] = k_init[:, :, -1:, :]

@@ -167,17 +170,40 @@ def __init__(self, config: ModelArgs):
         self.layers = [DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def __call__(
-        self, inputs: mx.array, mask: mx.array = None, cache: Any = None
-    ) -> mx.array:
+        self.sliding_window = config.sliding_window
+        self.first_swa_idx = None
+        if config.sliding_window_layers:
+            self.first_swa_idx = config.sliding_window_layers[0]
+
+        self.first_global_idx = None
+        self.swa_layers = set(config.sliding_window_layers)
+        for i in range(config.num_hidden_layers):
+            if i in self.swa_layers:
+                continue
+            self.first_global_idx = i
+            break
+
+    def __call__(self, inputs: mx.array, cache: Any = None) -> mx.array:
         x = self.embed_tokens(inputs)
-        if mask is None:
-            if cache is not None:
-                c = [cache[0][1]]
-            mask = create_attention_mask(x, c)
+
         if cache is None:
-            cache = [None] * len(self.layers)
-        for layer, c in zip(self.layers, cache):
+            cache = [(None, None)] * len(self.layers)
+
+        if self.first_global_idx is None:
+            c_global = None
+        else:
+            c_global = cache[self.first_global_idx][1]
+
+        if self.first_swa_idx is None:
+            c_swa = None
+        else:
+            c_swa = cache[self.first_swa_idx][1]
+
+        global_mask = create_attention_mask(x, c_global)
+        swa_mask = create_attention_mask(x, c_swa, window_size=self.sliding_window)
+
+        for l, (layer, c) in enumerate(zip(self.layers, cache)):
+            mask = swa_mask if l in self.swa_layers else global_mask
             x = layer(x, mask, c)
         return self.norm(x)

@@ -215,10 +241,8 @@ def sanitize(self, weights: dict) -> dict:
             weights["lm_head.weight"] = w
         return weights

-    def __call__(
-        self, inputs: mx.array, mask: mx.array = None, cache: Any = None
-    ) -> mx.array:
-        outputs = self.model(inputs, mask, cache)
+    def __call__(self, inputs: mx.array, cache: Any = None) -> mx.array:
+        outputs = self.model(inputs, cache)
         return self.lm_head(outputs)

     @property
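baichuan_m1 now builds one unrestricted mask and one sliding-window mask per forward pass and selects between them per layer. A small sketch of the two mask kinds using the updated create_attention_mask; the prompt length and window size here are made up for illustration:

```python
import mlx.core as mx

from mlx_lm.models.base import create_attention_mask

# Hypothetical hidden states for a 6-token prompt and a 4-token window.
h = mx.zeros((1, 6, 32))
sliding_window = 4

# Global layers: a plain causal mask, returned as the string "causal".
global_mask = create_attention_mask(h)

# Sliding-window layers: once the prompt is longer than the window, an
# explicit array mask is returned instead of the "causal" shorthand.
swa_mask = create_attention_mask(h, window_size=sliding_window)

print(global_mask)     # "causal"
print(swa_mask.shape)  # (6, 6)
```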

mlx_lm/models/bailing_moe.py
Lines changed: 3 additions & 6 deletions

@@ -239,17 +239,15 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ):
         h = self.word_embeddings(inputs)

-        if mask is None:
-            mask = create_attention_mask(h, cache)
-
         if cache is None:
             cache = [None] * len(self.layers)

+        mask = create_attention_mask(h, cache[0])
+
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)

@@ -268,10 +266,9 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: mx.array = None,
         cache=None,
     ):
-        h = self.model(inputs, mask, cache)
+        h = self.model(inputs, cache)
         return self.lm_head(h)

     def sanitize(self, weights):

mlx_lm/models/base.py
Lines changed: 10 additions & 22 deletions

@@ -7,8 +7,6 @@
 import mlx.core as mx
 from mlx.utils import tree_map

-from .cache import QuantizedKVCache
-

 @dataclass
 class BaseModelArgs:

@@ -43,26 +41,16 @@ def create_causal_mask(


 def create_attention_mask(
-    h: mx.array, cache: Optional[Any] = None, return_array: bool = False
+    h, cache=None, window_size: Optional[int] = None, return_array: bool = False
 ):
-    T = h.shape[1]
-    if T > 1:
-        offset = 0
-        window_size = None
-        if cache is not None and cache[0] is not None:
-            c = cache[0]
-            offset = c.offset
-            if hasattr(c, "max_size"):
-                window_size = c.max_size
-                offset = min(window_size, offset)
-                return_array = return_array or offset + T > window_size
-        if return_array:
-            return create_causal_mask(T, offset, window_size=window_size)
-        else:
-            return "causal"
-    else:
-        mask = None
-    return mask
+    N = h.shape[1]
+    if cache and hasattr(cache, "make_mask"):
+        return cache.make_mask(N, return_array=return_array, window_size=window_size)
+    if N == 1:
+        return None
+    if return_array or (window_size and N > window_size):
+        return create_causal_mask(N, window_size=window_size)
+    return "causal"


 def quantized_scaled_dot_product_attention(

@@ -117,7 +105,7 @@ def scaled_dot_product_attention(
     scale: float,
     mask: Optional[mx.array],
 ) -> mx.array:
-    if isinstance(cache, QuantizedKVCache):
+    if hasattr(cache, "bits"):
         return quantized_scaled_dot_product_attention(
             queries,
             keys,
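The rewritten create_attention_mask takes a single cache object instead of the per-layer cache list and defers to the cache's make_mask() when it has one. A minimal sketch of the resulting paths, assuming the KVCache shown in cache.py below:

```python
import mlx.core as mx

from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.cache import KVCache

# Hypothetical hidden states for a 5-token prompt.
h = mx.zeros((1, 5, 32))

# No cache: a multi-token input falls back to the "causal" string, which
# scaled_dot_product_attention understands directly.
print(create_attention_mask(h))             # "causal"

# With a cache that implements make_mask(), the decision is delegated to the
# cache, which knows its own offset (0 here, so still "causal").
print(create_attention_mask(h, KVCache()))  # "causal"

# Single-token decoding needs no mask at all.
print(create_attention_mask(h[:, :1], KVCache()))  # None
```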

mlx_lm/models/bitnet.py
Lines changed: 3 additions & 6 deletions

@@ -163,17 +163,15 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)

-        if mask is None:
-            mask = create_attention_mask(h, cache)
-
         if cache is None:
             cache = [None] * len(self.layers)

+        mask = create_attention_mask(h, cache[0])
+
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, cache=c)

@@ -192,10 +190,9 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, mask, cache)
+        out = self.model(inputs, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:

mlx_lm/models/cache.py
Lines changed: 45 additions & 0 deletions

@@ -6,6 +6,8 @@
 import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_map, tree_unflatten

+from .base import create_causal_mask
+

 def make_prompt_cache(
     model: nn.Module,

@@ -106,6 +108,17 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     return [c.trim(num_tokens) for c in cache][0]


+def create_attention_mask(
+    N: int, offset: int, return_array: bool, window_size: Optional[int]
+):
+    if N == 1:
+        return None
+    if return_array:
+        return create_causal_mask(N, offset, window_size=window_size)
+    else:
+        return "causal"
+
+
 class _BaseCache:
     @property
     def state(self):

@@ -170,6 +183,9 @@ def trim(self, n):
         self.offset -= n
         return n

+    def make_mask(self, *args, **kwargs):
+        return create_attention_mask(*args, offset=self.offset, **kwargs)
+

 class QuantizedKVCache(_BaseCache):
     def __init__(self, group_size: int = 64, bits: int = 8):

@@ -252,6 +268,9 @@ def trim(self, n):
         self.offset -= n
         return n

+    def make_mask(self, *args, **kwargs):
+        return create_attention_mask(*args, offset=self.offset, **kwargs)
+

 class KVCache(_BaseCache):
     def __init__(self):

@@ -317,6 +336,9 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         )
         return quant_cache

+    def make_mask(self, *args, **kwargs):
+        return create_attention_mask(*args, offset=self.offset, **kwargs)
+

 class RotatingKVCache(_BaseCache):


@@ -460,6 +482,29 @@
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         raise NotImplementedError("RotatingKVCache Quantization NYI")

+    def make_mask(
+        self, N: int, window_size: Optional[int] = None, return_array: bool = False
+    ):
+        if N > 1:
+            window_size = window_size or self.max_size
+            offset = min(self.max_size, self.offset)
+            if offset + N > window_size or return_array:
+                return create_causal_mask(N, offset, window_size=window_size)
+            else:
+                return "causal"
+        else:
+            if window_size is None:
+                return None
+            # May need a mask for when window_size < max_size
+            if self.offset >= window_size and self.max_size > window_size:
+                idx = self._idx
+                if idx >= self.max_size:
+                    idx = 0
+                mask_size = min(self.max_size, self.offset)
+                mask = mx.arange(mask_size) >= (mask_size - window_size)
+                mask = mx.roll(mask, shift=idx + 1)
+                return mask[:, None]
+

 class ArraysCache(_BaseCache):
     def __init__(self, size):
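The only non-trivial make_mask is the rotating cache's single-token path, which builds a boolean mask over the rotated buffer when a layer's window is narrower than the cache's max_size. A sketch under the assumption that RotatingKVCache accepts max_size and exposes the usual update_and_fetch(keys, values) API; the shapes and step count are made up:

```python
import mlx.core as mx

from mlx_lm.models.cache import RotatingKVCache

cache = RotatingKVCache(max_size=8)

# Simulate 12 single-token decode steps with dummy keys/values of shape
# (batch, heads, seq, head_dim) so the 8-entry buffer wraps around.
for _ in range(12):
    k = mx.zeros((1, 4, 1, 16))
    v = mx.zeros((1, 4, 1, 16))
    cache.update_and_fetch(k, v)

# A per-layer window narrower than max_size needs an explicit boolean mask
# over the rotated buffer during single-token decoding.
mask = cache.make_mask(1, window_size=4)
print(mask.shape)  # (8, 1): one flag per cached position

# Without a window (or with one at least as large as max_size), no mask is needed.
print(cache.make_mask(1))  # None
```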
