Skip to content

Commit f7a235d

Browse files
spokvulcanclaude
and authored
perf: eliminate CPU←GPU sync in penalty processors, optimize TopPSampler (#147)
* perf: eliminate CPU←GPU sync in penalty processors for 35% faster generation All three penalty processors (RepetitionContext, PresencePenaltyContext, FrequencyPenaltyContext) called `token.item(Int.self)` in `didSample()`, forcing a CPU←GPU synchronization on every generated token. This broke the `asyncEval()` pipelining in `TokenIterator`, preventing the GPU from working ahead on the next token while the current one was being returned. Fix: Replace Swift `[Int]` token buffers with GPU-resident ring buffers (`MLXArray`). Update positions via `MLX.where` mask operations that stay entirely on GPU. No `.item()` or `.asArray()` calls remain in the hot path. Benchmarked on Qwen3.5-4B (248K vocab, topK=20, presencePenalty=1.5): - Peak tok/s: 70 → 95 (+35%) - Aggregate tok/s across 14 scenarios: 34.6 → 57.2 (+65%) The improvement comes from restoring async GPU pipelining — the GPU can now compute the next token's forward pass while the current token is being returned to the caller. * perf: optimize TopPSampler with argPartition for topK sampling When topK is set (e.g., 20) and much smaller than vocabulary size (e.g., 248K), use argPartition O(V) to find top-K candidates instead of argSort O(V log V) on the entire vocabulary. Then sort only the K candidates O(K log K) for cumulative probability filtering. Also switch from take() to takeAlong() throughout to preserve input dimensions and avoid squeeze/unsqueeze shape issues. 
* style: fix swift-format formatting in penalty processors * refactor: address PR review — align sampling with Python mlx-lm - Remove Sendable from LogitProcessor (MLXArray is not Sendable) - Extract TokenRing struct to deduplicate ring buffer across penalty processors - Use MLXArray.arange() (GPU primitive) computed once in init - Replace softmax with logSoftmax, work in log-prob space throughout - Apply filters in Python mlx-lm order: top_p → min_p → top_k - Each filter operates on full vocabulary matching mlx_lm.sample_utils - Use putAlong for O(V) scatter instead of double argSort in applyTopP * Mirror mlx-lm top-k masking semantics --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9a8e3a5 commit f7a235d

File tree

1 file changed

+133
-118
lines changed

1 file changed

+133
-118
lines changed

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 133 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import Foundation
44
import MLX
5+
import MLXNN
56
import Tokenizers
67

78
/// A `LogitSampler` is responsible for sampling `logits` produced by
@@ -31,7 +32,7 @@ public protocol LogitSampler {
3132
/// ```
3233
///
3334
/// See also: ``LogitSampler``
34-
public protocol LogitProcessor: Sendable {
35+
public protocol LogitProcessor {
3536

3637
/// called before token generation starts with the text tokens of the prompt
3738
mutating func prompt(_ prompt: MLXArray)
@@ -206,11 +207,17 @@ public struct ArgMaxSampler: LogitSampler {
206207

207208
/// Sampler that uses probability filters (`topP`, `topK`, `minP`) and `temperature`
208209
/// to sample the logits.
210+
///
211+
/// Filters are applied in the same order as Python mlx-lm: top_p → min_p → top_k.
212+
/// Each filter operates on the full vocabulary in original token order, masking
213+
/// rejected tokens with `-inf`. This matches the composable filter chain in
214+
/// `mlx_lm.sample_utils.make_sampler`.
209215
public struct TopPSampler: LogitSampler {
210216
let temp: MLXArray
211217
let topP: MLXArray?
212218
let topK: Int?
213219
let minP: MLXArray?
220+
let negInf: MLXArray
214221
let randomState: MLXRandom.RandomState
215222

216223
public init(temperature: Float, topP: Float = 1.0, topK: Int = 0, minP: Float = 0.0) {
@@ -222,6 +229,7 @@ public struct TopPSampler: LogitSampler {
222229
}
223230
self.topK = topK > 0 ? topK : nil
224231
self.minP = minP > 0 ? MLXArray(minP) : nil
232+
self.negInf = MLXArray(-Float.infinity)
225233
self.randomState = MLXRandom.RandomState()
226234
}
227235

@@ -232,46 +240,55 @@ public struct TopPSampler: LogitSampler {
232240
}
233241

234242
return withRandomState(randomState) {
235-
// Match mlx-lm Python behavior:
236-
// apply filtering on the base distribution, then apply temperature at sampling time.
237-
let probs = softmax(logits, axis: -1)
238-
let sortedIndices = argSort(probs, axis: -1)
239-
240-
// probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
241-
let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0)
242-
243-
var filteredProbs = sortedProbs
243+
var logprobs = logSoftmax(logits)
244244

245+
// Apply filters in Python mlx-lm order: top_p → min_p → top_k.
245246
if let topP {
246-
let cumulativeProbs = cumsum(sortedProbs, axis: -1)
247-
filteredProbs = MLX.where(
248-
cumulativeProbs .> (1 - topP), filteredProbs, zeros(like: filteredProbs))
247+
logprobs = applyTopP(logprobs, topP: topP)
249248
}
250-
251249
if let minP {
252-
let maxProbs = sortedProbs[0..., -1].expandedDimensions(axis: -1)
253-
let keepMask = sortedProbs .>= (maxProbs * minP)
254-
filteredProbs = MLX.where(keepMask, filteredProbs, zeros(like: filteredProbs))
250+
logprobs = applyMinP(logprobs, minP: minP)
255251
}
256-
257252
if let topK {
258-
let vocabularySize = sortedProbs.dim(-1)
259-
if topK < vocabularySize {
260-
let cutOff = vocabularySize - topK
261-
let sortedPositions = MLXArray(Array(0 ..< vocabularySize))
262-
let keepMask = sortedPositions .>= cutOff
263-
filteredProbs = MLX.where(
264-
keepMask, filteredProbs, zeros(like: filteredProbs))
265-
}
253+
logprobs = applyTopK(logprobs, topK: topK)
266254
}
267255

268-
// Always keep the maximum-probability token so sampling always has a valid candidate.
269-
filteredProbs[0..., -1] = sortedProbs[0..., -1]
270-
271-
let sortedToken = categorical(log(filteredProbs) * (1 / temp))
272-
return sortedIndices.squeezed(axis: 0)[sortedToken]
256+
return categorical(logprobs * (1 / temp))
273257
}
274258
}
259+
260+
/// Keep tokens whose cumulative probability exceeds `1 - topP` (nucleus sampling).
261+
/// Matches `apply_top_p` from `mlx_lm/sample_utils.py`.
262+
private func applyTopP(_ logprobs: MLXArray, topP: MLXArray) -> MLXArray {
263+
let sortedIndices = argSort(logprobs, axis: -1)
264+
let sortedLogprobs = takeAlong(logprobs, sortedIndices, axis: -1)
265+
let sortedProbs = exp(sortedLogprobs)
266+
let cumulativeProbs = cumsum(sortedProbs, axis: -1)
267+
268+
// Mask low-probability tail in sorted order, scatter back to original vocab order.
269+
let filtered = MLX.where(cumulativeProbs .> (1 - topP), sortedLogprobs, negInf)
270+
return putAlong(logprobs, sortedIndices, values: filtered, axis: -1)
271+
}
272+
273+
/// Keep tokens with probability >= maxProb * minP.
274+
/// Matches `apply_min_p` from `mlx_lm/sample_utils.py`.
275+
private func applyMinP(_ logprobs: MLXArray, minP: MLXArray) -> MLXArray {
276+
// threshold in log-space: log(maxProb * minP) = maxLogprob + log(minP)
277+
let maxLogprob = logprobs.max(axis: -1, keepDims: true)
278+
let threshold = maxLogprob + log(minP)
279+
return MLX.where(logprobs .>= threshold, logprobs, negInf)
280+
}
281+
282+
/// Keep only the top-k highest-probability tokens.
283+
/// Mirrors `apply_top_k` from `mlx_lm/sample_utils.py`.
284+
private func applyTopK(_ logprobs: MLXArray, topK: Int) -> MLXArray {
285+
let vocabularySize = logprobs.dim(-1)
286+
guard topK < vocabularySize else { return logprobs }
287+
// O(V) partition on negated logprobs so top-k land at [0, topK).
288+
// Indices at [topK, V) are the tokens to mask out.
289+
let maskIndices = argPartition(-logprobs, kth: topK - 1, axis: -1)[0..., topK...]
290+
return putAlong(logprobs, maskIndices, values: negInf, axis: -1)
291+
}
275292
}
276293

277294
/// Sampler that uses `temperature` to sample the logits.
@@ -291,151 +308,149 @@ public struct CategoricalSampler: LogitSampler {
291308
}
292309
}
293310

294-
/// Processor that implements a `repetitionPenalty`
295-
public struct RepetitionContext: LogitProcessor {
296-
/// tokens in the repetition context sliding window
297-
var tokens = [Int]()
311+
/// GPU-resident ring buffer of recent token IDs.
312+
///
313+
/// Shared by penalty processors to avoid duplicating ring buffer logic.
314+
/// Uses `MLX.where` mask operations for GPU-only updates (no CPU←GPU sync),
315+
/// preserving `asyncEval()` pipelining in `TokenIterator`.
316+
struct TokenRing {
317+
private(set) var buffer: MLXArray
318+
private(set) var count = 0
319+
private var writeIndex = 0
320+
let capacity: Int
321+
private let positions: MLXArray
322+
323+
init(capacity: Int) {
324+
precondition(capacity > 0)
325+
self.capacity = capacity
326+
self.buffer = MLXArray.zeros([capacity], type: Int32.self)
327+
self.positions = MLXArray.arange(capacity)
328+
}
298329

299-
/// current write index into the tokens circular array
300-
var index = 0
330+
/// The valid portion of the ring (all of it once full), or `nil` if empty.
331+
var validTokens: MLXArray? {
332+
guard count > 0 else { return nil }
333+
return count < capacity ? buffer[..<count] : buffer
334+
}
301335

302-
/// penalty factor for repeating tokens
303-
let repetitionPenalty: Float
336+
/// Bulk-load from a prompt. Keeps the last `capacity` tokens.
337+
mutating func loadPrompt(_ prompt: MLXArray) {
338+
let n = prompt.dim(0)
339+
let promptTokens = prompt.asType(.int32)
340+
if n <= capacity {
341+
if n < capacity {
342+
let padding = MLXArray.zeros([capacity - n], type: Int32.self)
343+
buffer = concatenated([promptTokens.reshaped(-1), padding])
344+
} else {
345+
buffer = promptTokens.reshaped(-1)
346+
}
347+
count = n
348+
writeIndex = n % capacity
349+
} else {
350+
buffer = promptTokens[(-capacity)...].reshaped(-1)
351+
count = capacity
352+
writeIndex = 0
353+
}
354+
}
304355

305-
/// number of tokens to consider for repetition penalty
306-
let repetitionContextSize: Int
356+
/// Append a single token using GPU-only mask write (no CPU←GPU sync).
357+
mutating func append(_ token: MLXArray) {
358+
let mask = positions .== Int32(writeIndex)
359+
buffer = MLX.where(mask, token.asType(.int32), buffer)
360+
writeIndex = (writeIndex + 1) % capacity
361+
count = min(count + 1, capacity)
362+
}
363+
}
364+
365+
/// Processor that implements a `repetitionPenalty`.
366+
public struct RepetitionContext: LogitProcessor {
367+
private var ring: TokenRing
368+
let repetitionPenalty: Float
307369

308370
public init(repetitionPenalty: Float, repetitionContextSize: Int) {
309-
precondition(repetitionContextSize > 0)
310371
self.repetitionPenalty = repetitionPenalty
311-
self.repetitionContextSize = repetitionContextSize
372+
self.ring = TokenRing(capacity: repetitionContextSize)
312373
}
313374

314375
mutating public func prompt(_ prompt: MLXArray) {
315-
if prompt.shape[0] <= repetitionContextSize {
316-
self.tokens = prompt.asArray(Int.self)
317-
} else {
318-
self.tokens = prompt[(-repetitionContextSize)...].asArray(Int.self)
319-
}
376+
ring.loadPrompt(prompt)
320377
}
321378

322379
public func process(logits: MLXArray) -> MLXArray {
323-
if tokens.count > 0 {
324-
let indices = MLXArray(tokens.map { UInt32($0) })
325-
var selectedLogits = logits[0..., indices]
380+
guard let indices = ring.validTokens?.asType(.uint32) else { return logits }
381+
var selectedLogits = logits[0..., indices]
326382

327-
selectedLogits = MLX.where(
328-
selectedLogits .< 0, selectedLogits * repetitionPenalty,
329-
selectedLogits / repetitionPenalty)
330-
331-
logits[0..., indices] = selectedLogits
332-
return logits
333-
}
383+
selectedLogits = MLX.where(
384+
selectedLogits .< 0, selectedLogits * repetitionPenalty,
385+
selectedLogits / repetitionPenalty)
334386

387+
logits[0..., indices] = selectedLogits
335388
return logits
336389
}
337390

338391
mutating public func didSample(token: MLXArray) {
339-
if tokens.count >= repetitionContextSize {
340-
tokens[index] = token.item(Int.self)
341-
index = (index + 1) % repetitionContextSize
342-
} else {
343-
tokens.append(token.item(Int.self))
344-
}
392+
ring.append(token)
345393
}
346394
}
347395

348396
/// Processor that applies an additive presence penalty to tokens in a recent context window.
397+
///
398+
/// The penalty is applied once per unique token via scatter-write (writing the
399+
/// same value to the same index multiple times is idempotent).
349400
public struct PresencePenaltyContext: LogitProcessor {
350-
var tokens = [Int]()
351-
var index = 0
352-
401+
private var ring: TokenRing
353402
let presencePenalty: Float
354-
let presenceContextSize: Int
355403

356404
public init(presencePenalty: Float, presenceContextSize: Int) {
357-
precondition(presenceContextSize > 0)
358405
self.presencePenalty = presencePenalty
359-
self.presenceContextSize = presenceContextSize
406+
self.ring = TokenRing(capacity: presenceContextSize)
360407
}
361408

362409
mutating public func prompt(_ prompt: MLXArray) {
363-
if prompt.shape[0] <= presenceContextSize {
364-
self.tokens = prompt.asArray(Int.self)
365-
} else {
366-
self.tokens = prompt[(-presenceContextSize)...].asArray(Int.self)
367-
}
410+
ring.loadPrompt(prompt)
368411
}
369412

370413
public func process(logits: MLXArray) -> MLXArray {
371-
if tokens.isEmpty {
372-
return logits
373-
}
374-
375-
let uniqueTokens = Array(Set(tokens))
376-
let indices = MLXArray(uniqueTokens.map { UInt32($0) })
414+
guard let indices = ring.validTokens?.asType(.uint32) else { return logits }
377415
logits[0..., indices] = logits[0..., indices] - presencePenalty
378416
return logits
379417
}
380418

381419
mutating public func didSample(token: MLXArray) {
382-
if tokens.count >= presenceContextSize {
383-
tokens[index] = token.item(Int.self)
384-
index = (index + 1) % presenceContextSize
385-
} else {
386-
tokens.append(token.item(Int.self))
387-
}
420+
ring.append(token)
388421
}
389422
}
390423

391424
/// Processor that applies an additive frequency penalty to tokens in a recent context window.
425+
///
426+
/// Frequency counting is performed on GPU via `scatter_add` to build a histogram
427+
/// of token occurrences, avoiding CPU←GPU synchronization.
392428
public struct FrequencyPenaltyContext: LogitProcessor {
393-
var tokens = [Int]()
394-
var index = 0
395-
429+
private var ring: TokenRing
396430
let frequencyPenalty: Float
397-
let frequencyContextSize: Int
398431

399432
public init(frequencyPenalty: Float, frequencyContextSize: Int) {
400-
precondition(frequencyContextSize > 0)
401433
self.frequencyPenalty = frequencyPenalty
402-
self.frequencyContextSize = frequencyContextSize
434+
self.ring = TokenRing(capacity: frequencyContextSize)
403435
}
404436

405437
mutating public func prompt(_ prompt: MLXArray) {
406-
if prompt.shape[0] <= frequencyContextSize {
407-
self.tokens = prompt.asArray(Int.self)
408-
} else {
409-
self.tokens = prompt[(-frequencyContextSize)...].asArray(Int.self)
410-
}
438+
ring.loadPrompt(prompt)
411439
}
412440

413441
public func process(logits: MLXArray) -> MLXArray {
414-
if tokens.isEmpty {
415-
return logits
416-
}
442+
guard let validTokens = ring.validTokens else { return logits }
417443

418-
var counts = [Int: Int]()
419-
for token in tokens {
420-
counts[token, default: 0] += 1
421-
}
444+
let vocabSize = logits.dim(-1)
445+
let ones = MLXArray.ones([validTokens.dim(0)], type: Float32.self)
446+
let histogram = MLXArray.zeros([vocabSize], type: Float32.self)
447+
.at[validTokens.asType(.int32)].add(ones)
422448

423-
let orderedTokens = Array(counts.keys)
424-
let indices = MLXArray(orderedTokens.map { UInt32($0) })
425-
let penalties = MLXArray(
426-
orderedTokens.map { frequencyPenalty * Float(counts[$0] ?? 0) }
427-
)
428-
logits[0..., indices] = logits[0..., indices] - penalties
429-
return logits
449+
return logits - (histogram * frequencyPenalty).reshaped(1, -1)
430450
}
431451

432452
mutating public func didSample(token: MLXArray) {
433-
if tokens.count >= frequencyContextSize {
434-
tokens[index] = token.item(Int.self)
435-
index = (index + 1) % frequencyContextSize
436-
} else {
437-
tokens.append(token.item(Int.self))
438-
}
453+
ring.append(token)
439454
}
440455
}
441456

0 commit comments

Comments
 (0)