"""
Benchmarks AsyncMlxLM (MLX-based) against AsyncTransformer (HuggingFace-based)
using pytest-benchmark.

Run with:

    pytest benchmark/benchmark_mlx.py --benchmark-only --benchmark-group-by=func
"""
| 7 | + |
| 8 | +import pytest |
| 9 | +from .util import ( |
| 10 | + get_wikitext, |
| 11 | + token_prefixes, |
| 12 | + token_prefix_batches, |
| 13 | + run_await_next_token_logprobs, |
| 14 | + run_await_batch_next_token_logprobs, |
| 15 | +) |
| 16 | + |
| 17 | +from genlm.backend.llm import AsyncMlxLM, AsyncTransformer |
| 18 | + |
| 19 | +text = get_wikitext() |
| 20 | + |
| 21 | + |
| 22 | +def load_model(model, batch_size=None): |
| 23 | + model_name = "gpt2" |
| 24 | + if model == "mlx": |
| 25 | + return AsyncMlxLM.from_name(model_name) |
| 26 | + else: |
| 27 | + return AsyncTransformer.from_name(model_name, batch_size=batch_size) |


@pytest.mark.parametrize("model", ["mlx", "transformer"])
def test_await_next_token_logprobs(benchmark, model):
    llm = load_model(model, batch_size=1)
    sequences = token_prefixes(text, tokenizer=llm.tokenizer)
    run_await_next_token_logprobs(benchmark=benchmark, llm=llm, sequences=sequences)


@pytest.mark.parametrize("model", ["mlx", "transformer"])
def test_await_batch_next_token_logprobs(benchmark, model, batch_size=5):
    llm = load_model(model, batch_size=batch_size)
    batches = token_prefix_batches(text, tokenizer=llm.tokenizer, batch_size=batch_size)
    run_await_batch_next_token_logprobs(
        benchmark=benchmark, llm=llm, batches=batches, rounds=50, warmup_rounds=10
    )
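

# Illustrative only: run_await_next_token_logprobs and run_await_batch_next_token_logprobs
# come from benchmark/util.py, which is not shown in this file. The sketch below is an
# assumption about the general shape of the single-sequence helper, not the actual
# implementation; it assumes the llm exposes an async next_token_logprobs(token_ids)
# coroutine and uses pytest-benchmark's pedantic mode to control rounds and warmup.
# The _sketch_ prefix avoids clashing with the real helper.
import asyncio


def _sketch_run_await_next_token_logprobs(benchmark, llm, sequences, rounds=50, warmup_rounds=10):
    prefixes = list(sequences)

    def step():
        async def consume():
            for token_ids in prefixes:
                await llm.next_token_logprobs(token_ids)

        # One timed round awaits the next-token log-probs for every prefix.
        asyncio.run(consume())

    benchmark.pedantic(step, rounds=rounds, warmup_rounds=warmup_rounds)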