MLX-LM batching #52
Merged
41 commits
7f55a4a shepardxia: initial commit
32ab601 shepardxia: added naive logprobs and sampling methods
8fb131c shepardxia: add unit tests
23fffb8 shepardxia: fix format and linter issues
85333c3 shepardxia: fix format and linter issues
b6d0e97 shepardxia: fix format and linter issues
aec329b shepardxia: fix format and linter issues
5b1cd3f shepardxia: fix format and linter issues
8813df7 shepardxia: fix format and linter issues
7ee172d shepardxia: fixed after Ben's review
23b0d16 shepardxia: fixing imports
94d2163 shepardxia: separate coverage test job
938c040 shepardxia: add benchmarking code and removed unnecessary torch <-> mlx conversions
2d54f1e shepardxia: revert output back to torch tensors
9021fdc shepardxia: cache subclassing to fix coverage
fe0d412 shepardxia: cache subclassing to fix coverage
4883201 shepardxia: cache subclassing to fix coverage
8eafd77 shepardxia: merge with main updates
fb81350 shepardxia: cov
667d3ee shepardxia: update test
a34138c shepardxia: initial commit
cbef988 shepardxia: revising pytest params
3b19338 shepardxia: prevent bf16 batching for now
9444a81 shepardxia: adding coverage
317b972 shepardxia: add additional tests for mlx
a2c049c shepardxia: modify test
612b8ab shepardxia: modify test
da3f1d8 shepardxia: modify test
e9625ab shepardxia: modify test
57b9e99 shepardxia: modify test
3f017f0 shepardxia: Update with token trie kv cache
ef7c256 shepardxia: Update with token trie kv cache, fixing tests
dee7540 shepardxia: fixing tests
0bc1d92 shepardxia: fixing tests
8f5e575 shepardxia: fixing tests
124c878 shepardxia: fixing tests
eddfb94 shepardxia: fixing tests
c904a94 shepardxia: final fix
0703a9b shepardxia: Revised based on Ben's input. Updated HF cache construction to adhere…
8299ba0 shepardxia: Revised based on Ben's input. Updated HF cache construction to adhere…
8b58cdf shepardxia: Revised based on Ben's input. Updated HF cache construction to adhere…
```diff
@@ -1,18 +1,27 @@
+import asyncio
 from genlm.backend.llm.base import AsyncLM
 from genlm.backend.cache import OutputMLXCache
+from collections import defaultdict
 import torch

 from typing import (
     Any,
     Optional,
 )

 try:
     import mlx_lm
-    from mlx_lm.generate import generate_step
+    from mlx_lm.generate import generate_step, BatchGenerator, wired_limit
     import mlx.core as mx
     from mlx_lm.models import cache
     from mlx_lm.sample_utils import make_sampler
+    from mlx_lm.models.cache import (
+        ArraysCache,
+        CacheList,
+        KVCache,
+        RotatingKVCache,
+    )

     HAS_MLX = True
 except ImportError:  # pragma: no cover
```
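The `try`/`except ImportError` guard above makes MLX an optional dependency: the import is attempted once at module load and the outcome is recorded in a flag (`HAS_MLX`) that the rest of the module branches on. A minimal sketch of the same pattern, using a hypothetical package name rather than the real MLX imports:

```python
# Optional-dependency guard: attempt the import once at module load,
# record the outcome in a module-level flag, branch on the flag later.
try:
    import some_optional_backend  # hypothetical package name
    HAS_BACKEND = True
except ImportError:
    HAS_BACKEND = False


def require_backend():
    """Raise a clear error when a backend-only code path is reached."""
    if not HAS_BACKEND:
        raise ImportError("some_optional_backend is not installed")
```

This keeps the failure mode explicit: importing the module always succeeds, and only code paths that actually need the backend raise.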
```diff
@@ -39,17 +48,102 @@ def from_name(cls, *args, **kwargs):  # pragma: no cover

 else:

+    def _to_torch(logprobs):
+        """Converts MLX arrays into torch tensors."""
+        if isinstance(logprobs, mx.array):
+            if logprobs.dtype in [mx.bfloat16]:
+                logprobs = logprobs.astype(mx.float32)
+            return torch.tensor(logprobs)
+        elif isinstance(logprobs, (list, tuple)):
+            return [_to_torch(lp) for lp in logprobs]
+        return logprobs
+
+    def _has_bf16(mlx_lm_model):
+        def check(x):
+            if isinstance(x, dict):
+                return any(check(v) for v in x.values())
+            elif isinstance(x, mx.array):
+                return getattr(x, "dtype", None) == mx.bfloat16
+            else:
+                return False
+
+        return any(
+            check(param)
+            for layer in mlx_lm_model.layers
+            for param in layer.parameters().values()
+        )
+
+    def _cache_batchable(mlx_lm_model):
+        if not hasattr(mlx_lm_model, "make_cache"):
+            return True
+
+        cache = mlx_lm_model.make_cache()
+        batchable = (CacheList, KVCache, ArraysCache)
+        return all(
+            isinstance(c, batchable) or (isinstance(c, RotatingKVCache) and c.keep == 0)
+            for c in cache
+        )
+
+    def _supports_batching(mlx_lm_model):
+        """Return True only if MLX-LM has batching cache support for the model, and the model does not have bfloat16 parameters."""
+        return _cache_batchable(mlx_lm_model) and not _has_bf16(mlx_lm_model)
+
+    class BatchGeneratorCustom(BatchGenerator):
+        """A custom batch generator optimized for logprobs computation."""
+
+        def _next(self):
+            batch = self.active_batch
+            num_active = len(batch) if batch else 0
+            num_to_add = self.completion_batch_size - num_active
+            while num_to_add >= self.prefill_batch_size:
+                prompts = self.unprocessed_prompts[: self.prefill_batch_size]
+                # Finish processing the last examples of the last batch
+                if len(prompts) == 0 and num_active > 0:
+                    break
+                batch = self._process_prompts(prompts)
+                self.unprocessed_prompts = self.unprocessed_prompts[
+                    self.prefill_batch_size :
+                ]
+                # If there was no active batch, set it
+                if self.active_batch is None:
+                    self.active_batch = batch
+                else:
+                    self.active_batch.extend(batch)
+
+                num_active = len(self.active_batch)
+                num_to_add -= len(batch)
+
+            batch = self.active_batch
+            y, logprobs = batch.y, batch.logprobs
+            batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
+            mx.async_eval(batch.y, batch.logprobs)
+            return logprobs, batch
```
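`_supports_batching` gates batching on the model's per-layer cache types: every cache must be one of a whitelist of batchable classes, with one conditional exception for rotating caches that keep no prefix tokens. The same shape can be sketched with stand-in classes (all class names below are illustrative stand-ins, not the MLX-LM API):

```python
# Stand-in cache classes; the real ones live in mlx_lm.models.cache.
class KVCache:
    pass


class ArraysCache:
    pass


class RotatingKVCache:
    def __init__(self, keep=0):
        self.keep = keep


def cache_batchable(caches):
    # A cache layout is batchable if every per-layer cache is a known
    # batchable type, or a RotatingKVCache that keeps no prefix tokens.
    batchable = (KVCache, ArraysCache)
    return all(
        isinstance(c, batchable)
        or (isinstance(c, RotatingKVCache) and c.keep == 0)
        for c in caches
    )
```

A single non-batchable cache anywhere in the list disables batching for the whole model, which is the conservative choice.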
```python
    class Query:
        """A query to a language model, waiting to be batched."""

        def __init__(self, prompt, future):
            self.prompt = prompt
            self.future = future
```

Review comment (Member): We should be using a data class here, e.g., `@dataclass`.
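Per the review comment, the same holder written as a dataclass removes the boilerplate `__init__` and gains a readable `repr` and equality for free (a sketch; the field annotations are assumptions, since the original class is untyped):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Query:
    """A query to a language model, waiting to be batched."""
    prompt: list        # token IDs for the prompt
    future: Any = None  # asyncio.Future in practice
```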
```diff
 class AsyncMlxLM(AsyncLM):
-    def __init__(self, mlx_lm_model, tokenizer, cache_size=0, cache_opts={}):
+    def __init__(
+        self,
+        mlx_lm_model,
+        tokenizer,
+        cache_size=0,
+        cache_opts={},
+        batch_size=5,
+        timeout=0.02,
+        **batch_opts,
+    ):
         """Initialize an `AsyncMlxLM` instance.

         Args:
             mlx_lm_model (Model): The async MLX LM model instance.
             cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
             cache_opts (dict, optional): Additional options to pass to the [`OutputMLXCache`][genlm.backend.cache.OutputMLXCache] constructor. Defaults to {}.
         """

         self.mlx_lm_model = mlx_lm_model
         self.tokenizer = tokenizer
         self.cache = (
```
```diff
@@ -58,6 +152,12 @@ def __init__(self, mlx_lm_model, tokenizer, cache_size=0, cache_opts={}):
             else None
         )
         self.generation_stream = mx.new_stream(mx.default_device())
+        self.queries = []
+        self.batch_size = batch_size
+        self.timeout = timeout
+        self.timer = None
+        self.batching = _supports_batching(self.mlx_lm_model) and batch_size > 1
+        self.batch_opts = batch_opts

         super().__init__(tokenizer=self.tokenizer)
```
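One caveat with the `__init__` signature above: `cache_opts={}` is a mutable default argument, which Python evaluates once at function definition time, so any state written into that dict leaks across calls. It is harmless while the dict is only read, but the usual safe idiom is `cache_opts=None`. A minimal illustration of the pitfall (hypothetical functions, not from the PR):

```python
def counter_bad(opts={}):
    # The SAME dict object is reused on every call without an argument.
    opts["n"] = opts.get("n", 0) + 1
    return opts["n"]


def counter_good(opts=None):
    # A fresh dict per call avoids the shared state.
    opts = {} if opts is None else opts
    opts["n"] = opts.get("n", 0) + 1
    return opts["n"]
```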
```diff
@@ -80,6 +180,7 @@ def from_name(cls, model_name, **kwargs):

     def clear_cache(self):
         """Clear output cache."""
+        mx.clear_cache()
         if self.cache is not None:
             self.cache.clear()
```
```diff
@@ -139,64 +240,141 @@ def _step(input_tokens: mx.array):
             mx.async_eval(logprobs)
             return logprobs

-    async def next_token_logprobs(self, token_ids):
-        """Request log probabilities of next token asynchronously with output caching.
+    def reset_async_queries(self):
+        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
+        to completion."""
+        self.queries = []
+
+    def _batch_logits_custom(
+        self,
+        prompts,
+    ):
+        """
+        Compute next-token logits for each prompt in a batch using BatchGenerator.

         Args:
-            token_ids_list (list[int]): A list of token IDs, representing a prompt to the language model.
-            model (nn.Module): The language model.
+            prompts (List[List[int]]): Each inner list is a prompt of token IDs.
+            verbose (bool): If True, prints progress info.
+            kwargs: Passed through to BatchGenerator.

         Returns:
-            result (torch.Tensor): Normalized log probability tensor.
+            Tuple[List[mx.array], Stats]: A list of logits arrays (one per prompt),
+            and BatchGenerator statistics.
         """
+        gen = BatchGeneratorCustom(
+            self.mlx_lm_model, stop_tokens=[], **self.batch_opts
+        )
+        with wired_limit(self.mlx_lm_model, [self.generation_stream]):
+            _ = gen.insert(prompts, 1)
+            logprobs, batch = gen.next()
+            self.gen = batch
+        mx.clear_cache()
+        return logprobs
```
```diff
-        Warning:
-            Do not use `asyncio.run(next_token_logprobs())` as it may interfere with MLX's background loop.
-            For synchronous usage, use the `next_token_logprobs_sync()` method instead.
-        """
-        return self.next_token_logprobs_sync(token_ids)
-
-    def next_token_logprobs_sync(self, token_ids):
-        """Request log probabilities of next token synchronously.
+    def batch_evaluate_queries(self):
+        """
+        Process a batch of queued language model queries.
+
+        This method is called internally when the `batch_size` has been met or the `timeout` has expired.
+        """
+        queries, self.queries = self.queries, []
+        if len(queries) == 0:
+            return
+
+        query_groups = defaultdict(list)
+        for query in queries:
+            key = tuple(query.prompt)
+            query_groups[key].append(query)
+
+        # Use one representative query from each group
+        unique_queries = [group[0] for group in query_groups.values()]
+
+        input_prompts = [q.prompt for q in unique_queries]
+        if self.batching:
+            results = self._batch_logits_custom(
+                input_prompts,
+            )
+        else:
+            results = [
+                self.next_token_logprobs_sync(q.prompt) for q in unique_queries
+            ]
+
+        assert len(results) == len(unique_queries)
+
+        results = _to_torch(results)
+        for i, q in enumerate(unique_queries):
+            for dup_query in query_groups[tuple(q.prompt)]:
+                dup_query.future.set_result(results[i])
```
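`batch_evaluate_queries` deduplicates identical prompts by grouping on the prompt tuple, evaluating one representative per group, and fanning the result back out to every duplicate's future. The grouping step in isolation looks like this (a stdlib-only sketch; `group_by_prompt` is an illustrative name, not a function from the PR):

```python
from collections import defaultdict


def group_by_prompt(prompts):
    # Map each distinct prompt (as a hashable tuple) to the indices
    # of all requests that asked for it. Lists are unhashable, so the
    # tuple conversion is what makes the grouping possible.
    groups = defaultdict(list)
    for i, prompt in enumerate(prompts):
        groups[tuple(prompt)].append(i)
    return groups
```

With duplicate prompts in a batch, only `len(groups)` forward passes are needed instead of `len(prompts)`.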
```diff
+    def add_query(self, query, future):
+        """Add a query to be evaluated in the next batch.
+
+        This method is called internally when a `next_token_logprobs` request is made.
+
         Args:
-            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.
+            query (list[int]): Token IDs representing the query prompt
+            future (asyncio.Future): Future to store the result in
+        """
+        self.queries.append(Query(query, future))
+
+        if self.timer:
+            self.timer.cancel()
+            self.timer = None
+        if len(self.queries) >= self.batch_size:
+            self.batch_evaluate_queries()
+        else:
+            self.timer = asyncio.get_running_loop().call_later(
+                self.timeout, lambda: self.batch_evaluate_queries()
+            )
```
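Stripped of the MLX specifics, the `add_query` timer logic is a generic micro-batching pattern: each request parks on a future, and a single flush (triggered by reaching `batch_size` or by a timeout) resolves every pending future at once. A self-contained, runnable sketch under that assumption (class and function names are illustrative):

```python
import asyncio


class MicroBatcher:
    """Toy version of the queue-and-timer batching: requests accumulate
    until `batch_size` is reached or `timeout` expires, then one flush
    answers every pending future."""

    def __init__(self, compute, batch_size=4, timeout=0.02):
        self.compute = compute  # list of prompts -> list of results
        self.batch_size = batch_size
        self.timeout = timeout
        self.pending = []  # (prompt, future) pairs
        self.timer = None

    def _flush(self):
        pending, self.pending = self.pending, []
        if not pending:
            return
        results = self.compute([p for p, _ in pending])
        for (_, fut), res in zip(pending, results):
            fut.set_result(res)

    async def query(self, prompt):
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self.pending.append((prompt, fut))
        # Reset the timer: either the batch fills up now, or the
        # (re-armed) timeout will flush whatever has accumulated.
        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.pending) >= self.batch_size:
            self._flush()
        else:
            self.timer = loop.call_later(self.timeout, self._flush)
        return await fut


async def demo():
    # The "model" just sums each prompt's token IDs.
    batcher = MicroBatcher(compute=lambda ps: [sum(p) for p in ps], batch_size=2)
    return await asyncio.gather(batcher.query([1, 2]), batcher.query([3, 4]))
```

Concurrent callers (`asyncio.gather`) are transparently served by a single `compute` call, which is exactly the effect the PR is after for logprob requests.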
```diff
+    async def next_token_logprobs(self, token_ids):
+        """Request log probabilities of the next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.
+
+        Args:
+            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.
+
         Returns:
-            (torch.Tensor): Normalized log probability tensor.
+            logprobs (torch.Tensor): A tensor with the language model's normalized log probabilities for the next token following the prompt.
         """
         if not token_ids:
             raise ValueError("Token ids must not be empty")

         key = tuple(token_ids)

         if self.cache is not None and key in self.cache:
             return self.cache[key]

-        token_ids_array = mx.array(token_ids)
-        logprobs = self._generate_step_custom(token_ids_array)
-        logprobs = torch.tensor(logprobs)
+        future = asyncio.get_running_loop().create_future()
+        self.add_query(token_ids, future)
+        logprobs = await future
         if self.cache is not None:
             self.cache[key] = logprobs
         return logprobs
```
```diff
-    async def batch_next_token_logprobs(self, token_ids_list):
-        """
-        Request log probabilities of next tokens in a batch asynchronously.
-        Args:
-            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.
-        Returns:
-            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
-        """
-        return self.batch_next_token_logprobs_sync(token_ids_list)
+    def next_token_logprobs_sync(self, token_ids):
+        """Request log probabilities of next token synchronously.

-    def batch_next_token_logprobs_sync(self, token_ids_list):
-        """
-        Request log probabilities of next tokens in a batch synchronously.
         Args:
-            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.
+            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

         Returns:
-            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
+            (torch.Tensor): Normalized log probability tensor.
         """
-        outputs = []
-        for token_ids in token_ids_list:
-            outputs.append(self.next_token_logprobs_sync(token_ids))
-        return torch.stack(outputs)
+        if not token_ids:
+            raise ValueError("Token ids must not be empty")
+
+        key = tuple(token_ids)
+
+        if self.cache is not None and key in self.cache:
+            return self.cache[key]
+
+        token_ids_array = mx.array(token_ids)
+        logprobs = _to_torch(self._generate_step_custom(token_ids_array))
+        if self.cache is not None:
+            self.cache[key] = logprobs
+        return logprobs
```
```diff

     async def sample(
         self,
```