Conversation

@awni (Member) commented May 7, 2025

Not sure it's worth merging this. The standalone benchmark is much improved, but the end-to-end gain is very modest even for a very small model.

Adds a split logsumexp dispatch for the case where the reduced vector is very long.
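
For intuition, here is a minimal Python sketch of the split idea (my sketch, not the Metal kernel in this PR; num_splits is a made-up parameter for illustration). It works because a logsumexp over the whole axis equals a logsumexp of the per-chunk logsumexps:

import mlx.core as mx

def split_logsumexp(x, num_splits=8):
    # Reduce each chunk independently, then combine the per-chunk
    # results with a second, much smaller logsumexp.
    chunks = mx.split(x, num_splits, axis=-1)
    partials = mx.stack([mx.logsumexp(c, axis=-1) for c in chunks], axis=-1)
    return mx.logsumexp(partials, axis=-1, keepdims=True)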

import mlx.core as mx

x = mx.random.uniform(shape=(1, 4096 * 50))

def fun(x):
    for _ in range(100):
        x = x - mx.logsumexp(x, axis=-1, keepdims=True)
    return x
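
The PR doesn't show the timing harness; a plausible way to reproduce the numbers below (my assumption: mx.eval to force evaluation of the lazy graph and time.perf_counter for wall-clock time):

import time

mx.eval(x)  # materialize the input before timing
tic = time.perf_counter()
mx.eval(fun(x))  # force evaluation of the full chain of ops
print(f"{1e3 * (time.perf_counter() - tic):.0f} ms")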

Pre: 234 ms
Post: 138 ms

Inference speed improved slightly for small Gemmas (which have a large vocab):

mlx_lm.generate --model mlx-community/gemma-3-1b-it-4bit --prompt "Write a story about Einstein" -m 512

Pre: 333.652 tokens-per-sec
Post: 334.837 tokens-per-sec
