Skip to content
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
7f55a4a
initial commit
shepardxia Sep 12, 2025
32ab601
added naive logprobs and sampling methods
shepardxia Sep 15, 2025
8fb131c
add unit tests
shepardxia Sep 24, 2025
23fffb8
fix format and linter issues
shepardxia Oct 7, 2025
85333c3
fix format and linter issues
shepardxia Oct 7, 2025
b6d0e97
fix format and linter issues
shepardxia Oct 7, 2025
aec329b
fix format and linter issues
shepardxia Oct 7, 2025
5b1cd3f
fix format and linter issues
shepardxia Oct 7, 2025
8813df7
fix format and linter issues
shepardxia Oct 7, 2025
7ee172d
fixed after Ben's review
shepardxia Oct 8, 2025
23b0d16
fixing imports
shepardxia Oct 8, 2025
94d2163
separate coverage test job
shepardxia Oct 8, 2025
938c040
add benchmarking code and removed unnecessary torch <-> mlx conversions
shepardxia Oct 8, 2025
2d54f1e
revert output back to torch tensors
shepardxia Oct 8, 2025
9021fdc
cache subclassing to fix coverage
shepardxia Oct 9, 2025
fe0d412
cache subclassing to fix coverage
shepardxia Oct 9, 2025
4883201
cache subclassing to fix coverage
shepardxia Oct 9, 2025
8eafd77
merge with main updates
shepardxia Oct 17, 2025
fb81350
cov
shepardxia Oct 17, 2025
667d3ee
update test
shepardxia Oct 17, 2025
a34138c
initial commit
shepardxia Oct 23, 2025
cbef988
revising pytest params
shepardxia Oct 23, 2025
3b19338
prevent bf16 batching for now
shepardxia Oct 24, 2025
9444a81
adding coverage
shepardxia Oct 25, 2025
317b972
add additional tests for mlx
shepardxia Oct 27, 2025
a2c049c
modify test
shepardxia Oct 29, 2025
612b8ab
modify test
shepardxia Oct 29, 2025
da3f1d8
modify test
shepardxia Oct 29, 2025
e9625ab
modify test
shepardxia Oct 29, 2025
57b9e99
modify test
shepardxia Oct 29, 2025
3f017f0
Update with token trie kv cache
shepardxia Nov 12, 2025
ef7c256
Update with token trie kv cache, fixing tests
shepardxia Nov 12, 2025
dee7540
fixing tests
shepardxia Nov 12, 2025
0bc1d92
fixing tests
shepardxia Nov 12, 2025
8f5e575
fixing tests
shepardxia Nov 12, 2025
124c878
fixing tests
shepardxia Nov 12, 2025
eddfb94
fixing tests
shepardxia Nov 12, 2025
c904a94
final fix
shepardxia Nov 12, 2025
0703a9b
Revised based on Ben's input. Updated HF cache construction to adhere…
shepardxia Nov 13, 2025
8299ba0
Revised based on Ben's input. Updated HF cache construction to adhere…
shepardxia Nov 13, 2025
8b58cdf
Revised based on Ben's input. Updated HF cache construction to adhere…
shepardxia Nov 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
- name: Run MLX tests
run: |
source venv/bin/activate
coverage run --source=genlm/backend -m pytest tests/test_mlx.py --benchmark-disable
coverage run --source=genlm/backend -m pytest tests/test_mlx.py
coverage json --omit "*/test*"
coverage report --omit "*/test*"

Expand Down
21 changes: 0 additions & 21 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,3 @@ jobs:
pip install -e .[test]
pip install -r requirements-dev.txt
python -m pytest tests --ignore=tests/test_mlx.py

test-mlx:
runs-on: macos-14

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 1

- uses: actions/setup-python@v4
with:
python-version: 3.11.5
cache: 'pip'

- name: Run MLX Tests
run: |
python -m venv venv
source venv/bin/activate
pip install -e .[mlx]
pip install -r requirements-dev.txt
python -m pytest tests/test_mlx.py
16 changes: 13 additions & 3 deletions benchmark/benchmark_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,29 @@
def load_model(model, batch_size=None):
    """Build a gpt2 backend for benchmarking.

    Args:
        model (str): "mlx" for the MLX backend, anything else for the
            HuggingFace transformer backend.
        batch_size (int, optional): Forwarded to the backend constructor.

    Returns:
        An `AsyncMlxLM` or `AsyncTransformer` instance wrapping gpt2.
    """
    model_name = "gpt2"
    if model == "mlx":
        return AsyncMlxLM.from_name(model_name, batch_size=batch_size)
    return AsyncTransformer.from_name(model_name, batch_size=batch_size)


@pytest.mark.parametrize("model", ["mlx", "transformer"])
@pytest.mark.parametrize(
    "model",
    [
        "mlx",
    ],
)
def test_await_next_token_logprobs(benchmark, model):
    # Benchmark single-request next-token logprob latency (batch_size=1).
    # NOTE(review): only "mlx" is parametrized here; the "transformer"
    # case appears to have been dropped — confirm this is intentional.
    llm = load_model(model, batch_size=1)
    sequences = token_prefixes(text, tokenizer=llm.tokenizer)
    run_await_next_token_logprobs(benchmark=benchmark, llm=llm, sequences=sequences)


@pytest.mark.parametrize("model", ["mlx", "transformer"])
@pytest.mark.parametrize(
"model",
[
"mlx",
],
)
def test_await_batch_next_token_logprobs(benchmark, model, batch_size=5):
llm = load_model(model, batch_size=batch_size)
batches = token_prefix_batches(text, tokenizer=llm.tokenizer, batch_size=batch_size)
Expand Down
252 changes: 215 additions & 37 deletions genlm/backend/llm/mlx.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import asyncio
from genlm.backend.llm.base import AsyncLM
from genlm.backend.cache import OutputMLXCache
from collections import defaultdict
import torch

from typing import (
Any,
Optional,
)


try:
import mlx_lm
from mlx_lm.generate import generate_step
from mlx_lm.generate import generate_step, BatchGenerator, wired_limit
import mlx.core as mx
from mlx_lm.models import cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
RotatingKVCache,
)

HAS_MLX = True
except ImportError: # pragma: no cover
Expand All @@ -39,17 +48,102 @@ def from_name(cls, *args, **kwargs): # pragma: no cover

else:

def _to_torch(logprobs):
    """Recursively convert MLX arrays (and lists/tuples of them) to torch tensors.

    Non-array, non-sequence inputs are returned unchanged.
    """
    if isinstance(logprobs, mx.array):
        # bfloat16 cannot be handed to torch.tensor directly; upcast first.
        array = logprobs
        if array.dtype in [mx.bfloat16]:
            array = array.astype(mx.float32)
        return torch.tensor(array)
    if isinstance(logprobs, (list, tuple)):
        # Note: tuples come back as lists, matching the original behavior.
        return [_to_torch(item) for item in logprobs]
    return logprobs

def _has_bf16(mlx_lm_model):
    """Return True if any layer parameter of the model is stored as bfloat16."""

    def _is_bf16(value):
        # Parameters may be arbitrarily nested dicts of arrays; recurse.
        if isinstance(value, dict):
            return any(_is_bf16(inner) for inner in value.values())
        if isinstance(value, mx.array):
            return getattr(value, "dtype", None) == mx.bfloat16
        return False

    for layer in mlx_lm_model.layers:
        for param in layer.parameters().values():
            if _is_bf16(param):
                return True
    return False

def _cache_batchable(mlx_lm_model):
if not hasattr(mlx_lm_model, "make_cache"):
return True

cache = mlx_lm_model.make_cache()
batchable = (CacheList, KVCache, ArraysCache)
return all(
isinstance(c, batchable) or (isinstance(c, RotatingKVCache) and c.keep == 0)
for c in cache
)

def _supports_batching(mlx_lm_model):
    """Return True only if MLX-LM has batching cache support for the model, and does not have bfloat16 parameters."""
    # Preserve short-circuit order: check the cache first, bf16 second.
    if not _cache_batchable(mlx_lm_model):
        return False
    return not _has_bf16(mlx_lm_model)

class BatchGeneratorCustom(BatchGenerator):
    """A custom batch generator optimized for logprobs computation.

    Overrides `_next` to return the batch's current logprobs alongside the
    batch object itself, instead of the parent's return shape, so callers
    can read next-token log probabilities directly.
    """

    def _next(self):
        # Batch currently being decoded (None before the first step).
        batch = self.active_batch
        num_active = len(batch) if batch else 0
        # Free completion slots to fill with freshly prefilled prompts.
        num_to_add = self.completion_batch_size - num_active
        while num_to_add >= self.prefill_batch_size:
            prompts = self.unprocessed_prompts[: self.prefill_batch_size]
            # Finish processing the last examples of the last batch
            if len(prompts) == 0 and num_active > 0:
                break
            batch = self._process_prompts(prompts)
            self.unprocessed_prompts = self.unprocessed_prompts[
                self.prefill_batch_size :
            ]
            # If there was no active batch, set it
            if self.active_batch is None:
                self.active_batch = batch
            else:
                self.active_batch.extend(batch)

            num_active = len(self.active_batch)
            num_to_add -= len(batch)

        batch = self.active_batch
        # Snapshot the current tokens/logprobs, then advance one decode step;
        # the snapshot is what gets returned to the caller.
        y, logprobs = batch.y, batch.logprobs
        batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
        # Kick off lazy evaluation of the new step without blocking.
        mx.async_eval(batch.y, batch.logprobs)
        return logprobs, batch

class Query:
    """A pending language-model query awaiting batched evaluation.

    Attributes:
        prompt: Token IDs of the prompt to evaluate.
        future: asyncio.Future resolved with the next-token logprobs.
    """

    def __init__(self, prompt, future):
        self.prompt = prompt
        self.future = future

class AsyncMlxLM(AsyncLM):
def __init__(self, mlx_lm_model, tokenizer, cache_size=0, cache_opts={}):
def __init__(
self,
mlx_lm_model,
tokenizer,
cache_size=0,
cache_opts={},
batch_size=5,
timeout=0.02,
**batch_opts,
):
"""Initialize an `AsyncMlxLM` instance.

Args:
mlx_lm_model (Model): The async MLX LM model instance.
cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
cache_opts (dict, optional): Additional options to pass to the [`OutputMLXCache`][genlm.backend.cache.OutputMLXCache] constructor. Defaults to {}.

"""

self.mlx_lm_model = mlx_lm_model
self.tokenizer = tokenizer
self.cache = (
Expand All @@ -58,6 +152,12 @@ def __init__(self, mlx_lm_model, tokenizer, cache_size=0, cache_opts={}):
else None
)
self.generation_stream = mx.new_stream(mx.default_device())
self.queries = []
self.batch_size = batch_size
self.timeout = timeout
self.timer = None
self.batching = _supports_batching(self.mlx_lm_model) and batch_size > 1
self.batch_opts = batch_opts

super().__init__(tokenizer=self.tokenizer)

Expand All @@ -80,6 +180,7 @@ def from_name(cls, model_name, **kwargs):

def clear_cache(self):
    """Clear output cache."""
    # Release MLX's internal memory pool as well as our output cache.
    mx.clear_cache()
    if self.cache is None:
        return
    self.cache.clear()

Expand Down Expand Up @@ -139,64 +240,141 @@ def _step(input_tokens: mx.array):
mx.async_eval(logprobs)
return logprobs

async def next_token_logprobs(self, token_ids):
"""Request log probabilities of next token asynchronously with output caching.
def reset_async_queries(self):
"""Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
to completion."""
self.queries = []

def _batch_logits_custom(
    self,
    prompts,
):
    """
    Compute next-token log probabilities for each prompt in a batch using
    `BatchGeneratorCustom`.

    Fixed docstring: the previous version documented `model`, `verbose`,
    and `kwargs` parameters that do not exist, and a tuple return type.

    Args:
        prompts (List[List[int]]): Each inner list is a prompt of token IDs.

    Returns:
        (mx.array): Log probabilities for the next token of each prompt.
        The batch object is stored on `self.gen` — presumably so its KV
        cache stays available for reuse; confirm against callers.
    """
    gen = BatchGeneratorCustom(
        self.mlx_lm_model, stop_tokens=[], **self.batch_opts
    )
    with wired_limit(self.mlx_lm_model, [self.generation_stream]):
        # Insert all prompts at once and take a single decode step.
        _ = gen.insert(prompts, 1)
        logprobs, batch = gen.next()
    self.gen = batch
    # Release MLX's memory pool after the batched step.
    mx.clear_cache()
    return logprobs

Warning:
Do not use `asyncio.run(next_token_logprobs())` as it may interfere with MLX's background loop.
For synchronous usage, use the `next_token_logprobs_sync()` method instead.
def batch_evaluate_queries(self):
    """
    Process a batch of queued language model queries.

    This method is called internally when the `batch_size` has been met or the `timeout` has expired.
    """

    # Take ownership of the pending queue atomically (single-threaded loop).
    queries, self.queries = self.queries, []
    if len(queries) == 0:
        return

    # Group duplicate prompts so each unique prompt is evaluated only once.
    query_groups = defaultdict(list)
    for query in queries:
        key = tuple(query.prompt)
        query_groups[key].append(query)

    # Use one representative query from each group
    unique_queries = [group[0] for group in query_groups.values()]

    input_prompts = [q.prompt for q in unique_queries]
    if self.batching:
        results = self._batch_logits_custom(
            input_prompts,
        )
    else:
        # Model/config does not support batching: evaluate serially.
        results = [
            self.next_token_logprobs_sync(q.prompt) for q in unique_queries
        ]

    assert len(results) == len(unique_queries)

    # Convert MLX outputs to torch; torch tensors pass through unchanged.
    results = _to_torch(results)
    # Fan each unique result out to every duplicate query's future.
    for i, q in enumerate(unique_queries):
        for dup_query in query_groups[tuple(q.prompt)]:
            dup_query.future.set_result(results[i])

def add_query(self, query, future):
    """Queue a prompt for evaluation in the next batch.

    Called internally when a `next_token_logprobs` request is made. The
    queue is flushed immediately once `batch_size` queries are pending;
    otherwise a timeout is (re)armed so stragglers still get evaluated.

    Args:
        query (list[int]): Token IDs representing the query prompt.
        future (asyncio.Future): Future to store the result in.
    """
    self.queries.append(Query(query, future))

    # Any previously armed flush timer is now stale; disarm it.
    if self.timer:
        self.timer.cancel()
        self.timer = None

    if len(self.queries) < self.batch_size:
        # Not full yet: flush after `timeout` seconds if nothing else arrives.
        loop = asyncio.get_running_loop()
        self.timer = loop.call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )
    else:
        self.batch_evaluate_queries()

async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    key = tuple(token_ids)

    # Serve from the output cache when available.
    if self.cache is not None and key in self.cache:
        return self.cache[key]

    # Hand the request to the batching queue and wait for the batch flush
    # (triggered by batch_size being reached or the timeout firing).
    future = asyncio.get_running_loop().create_future()
    self.add_query(token_ids, future)
    logprobs = await future
    if self.cache is not None:
        self.cache[key] = logprobs
    return logprobs

async def batch_next_token_logprobs(self, token_ids_list):
"""
Request log probabilities of next tokens in a batch asynchronously.
Args:
token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.
Returns:
(torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
"""
return self.batch_next_token_logprobs_sync(token_ids_list)
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    Unlike `next_token_logprobs`, this bypasses the batching queue and
    evaluates the prompt immediately on the model.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    key = tuple(token_ids)

    # Serve from the output cache when available.
    if self.cache is not None and key in self.cache:
        return self.cache[key]

    token_ids_array = mx.array(token_ids)
    # Run one generation step and convert the MLX logprobs to torch.
    logprobs = _to_torch(self._generate_step_custom(token_ids_array))
    if self.cache is not None:
        self.cache[key] = logprobs
    return logprobs

async def sample(
self,
Expand Down
Loading