MLX-LM batching #52

Merged
shepardxia merged 41 commits into main from mlx-lm-batching
Nov 13, 2025

Conversation

Contributor

@shepardxia commented Oct 23, 2025

Adds a batching function to the MLX-LM backend, with KV caching.

@codecov

codecov bot commented Oct 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.


@shepardxia marked this pull request as ready for review October 29, 2025 21:21
@shepardxia
Contributor Author

@benlebrun PR ready for review!

Member

@benlebrun left a comment

Looks good overall. I left a few comments; I think this can be cleaned up a bit.

One thing to consider: we don't actually need the timeout functionality for async batching. The basic idea is to have a queue and a background task that eagerly pulls from the queue and processes all requests together in a batch.

Because of the way that the asyncio scheduler works, all concurrent requests will be batched together. The background task grabs one item with an await, and then drains the queue with get_nowait() to form the batch. Because other coroutines that call _queue_request will run while the background loop awaits the first get(), those concurrent requests will land in the same batch.

A rough sketch looks like this:

import asyncio
from typing import Any


class AutoBatchedSketch:

    def __init__(self):
        self._queue = None
        self._task = None

    def _start(self):
        if not self._task or self._task.done():
            self._queue = asyncio.Queue()
            self._task = asyncio.create_task(self._background_loop())

    def _queue_request(self, request):
        if not self._task or self._task.done():
            self._start()

        future = asyncio.get_running_loop().create_future()
        self._queue.put_nowait((request, future))
        return future

    async def next_token_logprobs(self, token_ids) -> Any:
        """Public API. Enqueue a request and await its result."""
        return await self._queue_request(token_ids)

    async def _background_loop(self):
        while True:
            requests = []
            try:
                # Block until at least one request arrives.
                requests = [await self._queue.get()]

                # Drain anything else already queued into the same batch.
                try:
                    while True:
                        requests.append(self._queue.get_nowait())
                except asyncio.QueueEmpty:
                    pass

                # _batch_call is the model-specific batched evaluation.
                inputs, futures = zip(*requests)
                results = self._batch_call(inputs)
                for future, result in zip(futures, results):
                    future.set_result(result)

            except Exception as e:
                for _, future in requests:
                    if not future.done():
                        future.set_exception(e)
                raise

Not saying that we need to implement this approach now, but it is worth keeping in mind as a more efficient approach that doesn't require specifying a batch size and a timeout. We used it here: https://github.com/genlm/genlm-bytes/blob/e76ca6908b2360690e5ecf098b377395b342978a/genlm/bytes/trie.py#L484.
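
For illustration only (not part of the PR): a small driver, assuming the AutoBatchedSketch class from the sketch above, showing how concurrent callers land in a single batch. EchoBatcher and its _batch_call are stand-ins for the real batched model evaluation.

import asyncio

class EchoBatcher(AutoBatchedSketch):
    def _batch_call(self, inputs):
        # Stand-in for the real batched model call.
        print(f"processing a batch of {len(inputs)} requests")
        return [f"result for {x}" for x in inputs]

async def main():
    batcher = EchoBatcher()
    # All three requests are scheduled concurrently, so the background
    # loop drains them into a single _batch_call.
    results = await asyncio.gather(
        batcher.next_token_logprobs([1, 2]),
        batcher.next_token_logprobs([3]),
        batcher.next_token_logprobs([4, 5, 6]),
    )
    print(results)

asyncio.run(main())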

@shepardxia requested a review from benlebrun November 12, 2025 22:03
Member

@benlebrun left a comment

Looks great! Just left a few comments, then should be good to merge!


else:

class Query:
Member

We should be using a data class here, e.g.,:

@dataclass
class Query:
    prompt: str
    future: asyncio.Future
    past: Optional[mx.array] = None

self.generation_stream = mx.new_stream(mx.default_device())

self.queries = []
self.batch_size = (
Member

add a warning to let the user know that the model is not batchable.
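
As a rough illustration (the helper name and arguments here are hypothetical, not the PR's actual code), the warning could look something like this:

import warnings

def resolve_batch_size(requested: int, batchable: bool) -> int:
    """Pick the effective batch size, warning when batching is unsupported."""
    if not batchable and requested > 1:
        warnings.warn(
            "Model does not support batched evaluation; "
            "falling back to batch_size=1.",
            UserWarning,
        )
        return 1
    return requested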

@staticmethod
def _to_torch(logprobs):
    """Convert MLX arrays into PyTorch tensors."""
    if logprobs.dtype in [mx.bfloat16]:
Member

Can use is here
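
For context, the line in question tests membership in a single-element list; a hedged sketch of the suggested direct comparison follows. Equality against the dtype constant is the conservative form; identity should also work if MLX exposes dtypes as singletons, but that is an assumption here.

import mlx.core as mx

x = mx.zeros((2, 3), dtype=mx.bfloat16)

# Current form in the diff: membership in a one-element list.
if x.dtype in [mx.bfloat16]:
    pass

# Suggested form: compare against the dtype directly.
if x.dtype == mx.bfloat16:
    pass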

… to new API, with dependency updated to enforce the lowest compatible version.
@shepardxia merged commit 73197e7 into main Nov 13, 2025
5 checks passed