
Commit 7ec0c05

Merge pull request #48 from genlm/mlx-lm
Add MLX-LM backend support for faster inference on Apple silicon. Parallel inference batching coming soon.
2 parents a457b68 + 1a97691 commit 7ec0c05

12 files changed: +586 additions, -5 deletions

.github/workflows/coverage.yml

Lines changed: 36 additions & 1 deletion
@@ -32,7 +32,42 @@ jobs:
       - name: Run tests
         run: |
           source venv/bin/activate
-          coverage run --source=genlm/backend -m pytest --benchmark-disable
+          coverage run --source=genlm/backend -m pytest --benchmark-disable --ignore=tests/test_mlx.py
+          coverage json --omit "*/test*"
+          coverage report --omit "*/test*"
+
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v5
+        with:
+          fail_ci_if_error: false
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: ./coverage.json
+          slug: genlm/genlm-backend
+
+  test_mlx_coverage:
+    runs-on: macos-14
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 1
+
+      - uses: actions/setup-python@v4
+        with:
+          python-version: 3.11.5
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m venv venv
+          source venv/bin/activate
+          pip install -e .[mlx]
+          pip install -r requirements-dev.txt
+
+      - name: Run MLX tests
+        run: |
+          source venv/bin/activate
+          coverage run --source=genlm/backend -m pytest tests/test_mlx.py --benchmark-disable
           coverage json --omit "*/test*"
           coverage report --omit "*/test*"

.github/workflows/pytest.yml

Lines changed: 22 additions & 1 deletion
@@ -29,4 +29,25 @@ jobs:
           source venv/bin/activate
           pip install -e .[test]
           pip install -r requirements-dev.txt
-          python -m pytest tests
+          python -m pytest tests --ignore=tests/test_mlx.py
+
+  test-mlx:
+    runs-on: macos-14
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 1
+
+      - uses: actions/setup-python@v4
+        with:
+          python-version: 3.11.5
+          cache: 'pip'
+
+      - name: Run MLX Tests
+        run: |
+          python -m venv venv
+          source venv/bin/activate
+          pip install -e .[mlx]
+          pip install -r requirements-dev.txt
+          python -m pytest tests/test_mlx.py

DEVELOPING.md

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,11 @@ uv pip install -e ".[docs]"
 uv pip install -r requirements-dev.txt
 ```
 
+To build with MLX support, run:
+```bash
+uv pip install -e ".[mlx]"
+```
+
 ## Testing
 
 When test dependencies are installed, the test suite can be run via:

README.md

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@ See our [documentation](https://genlm.github.io/genlm-backend/).
 - Automatic batching of concurrent log-probability requests, enabling efficient large-scale inference without having to write batching logic yourself
 - Byte-level decoding of transformers tokenizers, enabling advanced token-level control
 - Support for arbitrary Hugging Face models (e.g., LLaMA, DeepSeek, etc.) with fast inference and automatic KV caching using vllm
+- NEW: Support for the MLX-LM library, allowing faster inference on Apple silicon devices
 
 
 ## ⚡ Quick Start
@@ -28,6 +29,13 @@ This library supports installation via pip:
 pip install genlm-backend
 ```
 
+Or, to install with MLX support, run:
+
+```bash
+pip install genlm-backend[mlx]
+```
+
+
 ## 🧪 Example: Autobatched Sequential Importance Sampling with LLMs
 
 This example demonstrates how `genlm-backend` enables concise, scalable probabilistic inference with language models. It implements a Sequential Importance Sampling (SIS) algorithm that makes asynchronous log-probability requests which get automatically batched by the language model.
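With the `[mlx]` extra installed, the new backend can be used like the other `AsyncLM` implementations. A minimal sketch, assuming an Apple-silicon machine; the awaitable `next_token_logprobs(token_ids)` call is inferred from the benchmark helpers added in this PR rather than from the README itself:

```python
# Minimal sketch (assumes `pip install genlm-backend[mlx]` on Apple silicon).
# The awaitable next_token_logprobs(token_ids) API is an assumption inferred
# from the benchmark helpers in this PR, not an official README example.
import asyncio

from genlm.backend.llm import AsyncMlxLM


async def main():
    llm = AsyncMlxLM.from_name("gpt2")  # same model name used in benchmark_mlx.py
    token_ids = llm.tokenizer.encode("The capital of France is")
    logprobs = await llm.next_token_logprobs(token_ids)  # assumed method name
    print(logprobs.shape)  # one log-probability per vocabulary entry


asyncio.run(main())
```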

benchmark/benchmark_mlx.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+"""
+Evaluates performance differences between AsyncMlxLM (MLX-based) and AsyncTransformer
+(HuggingFace-based) implementations using pytest-benchmark.
+
+pytest benchmark/benchmark_mlx.py --benchmark-only --benchmark-group-by=func
+"""
+
+import pytest
+from .util import (
+    get_wikitext,
+    token_prefixes,
+    token_prefix_batches,
+    run_await_next_token_logprobs,
+    run_await_batch_next_token_logprobs,
+)
+
+from genlm.backend.llm import AsyncMlxLM, AsyncTransformer
+
+text = get_wikitext()
+
+
+def load_model(model, batch_size=None):
+    model_name = "gpt2"
+    if model == "mlx":
+        return AsyncMlxLM.from_name(model_name)
+    else:
+        return AsyncTransformer.from_name(model_name, batch_size=batch_size)
+
+
+@pytest.mark.parametrize("model", ["mlx", "transformer"])
+def test_await_next_token_logprobs(benchmark, model):
+    llm = load_model(model, batch_size=1)
+    sequences = token_prefixes(text, tokenizer=llm.tokenizer)
+    run_await_next_token_logprobs(benchmark=benchmark, llm=llm, sequences=sequences)
+
+
+@pytest.mark.parametrize("model", ["mlx", "transformer"])
+def test_await_batch_next_token_logprobs(benchmark, model, batch_size=5):
+    llm = load_model(model, batch_size=batch_size)
+    batches = token_prefix_batches(text, tokenizer=llm.tokenizer, batch_size=batch_size)
+    run_await_batch_next_token_logprobs(
+        benchmark=benchmark, llm=llm, batches=batches, rounds=50, warmup_rounds=10
+    )
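For readers who want a rough sense of what the pytest-benchmark tests above measure, here is a standalone sketch of the same comparison. The `.util` helpers are not part of this diff, so the prompt handling and the awaitable `next_token_logprobs(token_ids)` call below are assumptions based on the helper names, not the actual utilities:

```python
# Rough standalone sketch of the comparison that benchmark_mlx.py automates.
# Assumes both backends expose an awaitable next_token_logprobs(token_ids)
# method, as suggested by the run_await_next_token_logprobs helper name.
import asyncio
import time

from genlm.backend.llm import AsyncMlxLM, AsyncTransformer


async def time_backend(llm, prompts, label):
    start = time.perf_counter()
    for prompt in prompts:
        token_ids = llm.tokenizer.encode(prompt)
        await llm.next_token_logprobs(token_ids)  # assumed API
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.3f}s for {len(prompts)} prompts")


async def main():
    prompts = ["The quick brown fox", "MLX runs on", "Sequential importance sampling"]
    await time_backend(AsyncMlxLM.from_name("gpt2"), prompts, "mlx")
    await time_backend(AsyncTransformer.from_name("gpt2", batch_size=1), prompts, "transformer")


asyncio.run(main())
```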

genlm/backend/cache.py

Lines changed: 27 additions & 0 deletions
@@ -43,6 +43,33 @@ def clear(self):
         self.cache.clear()
 
 
+class OutputMLXCache(OutputCache):
+    """A cache for storing tensor outputs with MLX.
+
+    Since MLX uses unified memory, we don't need to move tensors between CPU and GPU.
+
+    Args:
+        maxsize (int): Maximum number of items to store in the cache
+    """
+
+    def __init__(self, maxsize):
+        super().__init__(maxsize, move_to_cpu=False)
+
+    def __getitem__(self, key):
+        if key in self.cache:
+            value = self.cache.pop(key)
+            self.cache[key] = value
+            return value
+        raise KeyError(key)
+
+    def __setitem__(self, key, value):
+        if len(self.cache) >= self.maxsize:
+            _, old_tensor = self.cache.popitem(last=False)
+            del old_tensor
+
+        self.cache[key] = value
+
+
 class TokenTrie:
     """Class used internally to cache language model results.
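The new class is an LRU cache: `__getitem__` re-inserts a hit so it becomes the most recently used entry, and `__setitem__` evicts the oldest entry once `maxsize` is reached, with no device transfers since MLX memory is unified. A small illustrative sketch of that behavior, using plain strings as stand-ins for the MLX tensors the cache normally holds and assuming `OutputCache` keeps entries in an insertion-ordered `self.cache` mapping (implied by `popitem(last=False)`):

```python
# Illustrative sketch of OutputMLXCache's LRU eviction; plain strings stand in
# for the MLX tensors the cache normally stores.
from genlm.backend.cache import OutputMLXCache

cache = OutputMLXCache(maxsize=2)
cache[("the", "cat")] = "logprobs_1"
cache[("the", "dog")] = "logprobs_2"

_ = cache[("the", "cat")]            # hit: re-inserted, now most recently used
cache[("a", "bird")] = "logprobs_3"  # at capacity: evicts ("the", "dog"), the LRU entry

assert ("the", "cat") in cache.cache
assert ("the", "dog") not in cache.cache
```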

genlm/backend/llm/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 from genlm.backend.llm.vllm import AsyncVirtualLM
 from genlm.backend.llm.hf import AsyncTransformer
 from genlm.backend.llm.base import AsyncLM, MockAsyncLM
+from genlm.backend.llm.mlx import AsyncMlxLM
 
 import torch
 
@@ -33,6 +34,8 @@ def load_model_by_name(name, backend=None, llm_opts=None):
         return AsyncTransformer.from_name(name, **llm_opts)
     elif backend == "mock":
         return MockAsyncLM.from_name(name, **llm_opts)
+    elif backend == "mlx":
+        return AsyncMlxLM.from_name(name, **llm_opts)
     else:
         raise ValueError(f"Invalid backend: {backend}")
 
@@ -42,5 +45,6 @@ def load_model_by_name(name, backend=None, llm_opts=None):
     "AsyncLM",
     "AsyncVirtualLM",
     "AsyncTransformer",
+    "AsyncMlxLM",
     "MockAsyncLM",
 ]
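With this registration, the MLX backend is selectable through the existing `load_model_by_name` dispatcher. A minimal sketch; the default behavior when `backend` is omitted is unchanged by this diff and not shown:

```python
# Selecting the newly registered backend via the dispatcher shown above.
from genlm.backend.llm import load_model_by_name

# Explicitly request the MLX backend (requires the [mlx] extra and Apple silicon).
llm = load_model_by_name("gpt2", backend="mlx")

# Other registered backends remain available through the same entry point.
mock = load_model_by_name("gpt2", backend="mock")
```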
