22
33import Foundation
44import MLX
5+ import MLXNN
56import Tokenizers
67
78/// A `LogitSampler` is responsible for sampling `logits` produced by
@@ -31,7 +32,7 @@ public protocol LogitSampler {
3132/// ```
3233///
3334/// See also: ``LogitSampler``
34- public protocol LogitProcessor : Sendable {
35+ public protocol LogitProcessor {
3536
3637 /// called before token generation starts with the text tokens of the prompt
3738 mutating func prompt( _ prompt: MLXArray )
@@ -206,11 +207,17 @@ public struct ArgMaxSampler: LogitSampler {
206207
207208/// Sampler that uses probability filters (`topP`, `topK`, `minP`) and `temperature`
208209/// to sample the logits.
210+ ///
211+ /// Filters are applied in the same order as Python mlx-lm: top_p → min_p → top_k.
212+ /// Each filter operates on the full vocabulary in original token order, masking
213+ /// rejected tokens with `-inf`. This matches the composable filter chain in
214+ /// `mlx_lm.sample_utils.make_sampler`.
209215public struct TopPSampler : LogitSampler {
210216 let temp : MLXArray
211217 let topP : MLXArray ?
212218 let topK : Int ?
213219 let minP : MLXArray ?
220+ let negInf : MLXArray
214221 let randomState : MLXRandom . RandomState
215222
216223 public init ( temperature: Float , topP: Float = 1.0 , topK: Int = 0 , minP: Float = 0.0 ) {
@@ -222,6 +229,7 @@ public struct TopPSampler: LogitSampler {
222229 }
223230 self . topK = topK > 0 ? topK : nil
224231 self . minP = minP > 0 ? MLXArray ( minP) : nil
232+ self . negInf = MLXArray ( - Float. infinity)
225233 self . randomState = MLXRandom . RandomState ( )
226234 }
227235
@@ -232,46 +240,55 @@ public struct TopPSampler: LogitSampler {
232240 }
233241
234242 return withRandomState ( randomState) {
235- // Match mlx-lm Python behavior:
236- // apply filtering on the base distribution, then apply temperature at sampling time.
237- let probs = softmax ( logits, axis: - 1 )
238- let sortedIndices = argSort ( probs, axis: - 1 )
239-
240- // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
241- let sortedProbs = take ( probs, sortedIndices, axis: - 1 ) . squeezed ( axis: 0 )
242-
243- var filteredProbs = sortedProbs
243+ var logprobs = logSoftmax ( logits)
244244
245+ // Apply filters in Python mlx-lm order: top_p → min_p → top_k.
245246 if let topP {
246- let cumulativeProbs = cumsum ( sortedProbs, axis: - 1 )
247- filteredProbs = MLX . where (
248- cumulativeProbs .> ( 1 - topP) , filteredProbs, zeros ( like: filteredProbs) )
247+ logprobs = applyTopP ( logprobs, topP: topP)
249248 }
250-
251249 if let minP {
252- let maxProbs = sortedProbs [ 0 ... , - 1 ] . expandedDimensions ( axis: - 1 )
253- let keepMask = sortedProbs .>= ( maxProbs * minP)
254- filteredProbs = MLX . where ( keepMask, filteredProbs, zeros ( like: filteredProbs) )
250+ logprobs = applyMinP ( logprobs, minP: minP)
255251 }
256-
257252 if let topK {
258- let vocabularySize = sortedProbs. dim ( - 1 )
259- if topK < vocabularySize {
260- let cutOff = vocabularySize - topK
261- let sortedPositions = MLXArray ( Array ( 0 ..< vocabularySize) )
262- let keepMask = sortedPositions .>= cutOff
263- filteredProbs = MLX . where (
264- keepMask, filteredProbs, zeros ( like: filteredProbs) )
265- }
253+ logprobs = applyTopK ( logprobs, topK: topK)
266254 }
267255
268- // Always keep the maximum-probability token so sampling always has a valid candidate.
269- filteredProbs [ 0 ... , - 1 ] = sortedProbs [ 0 ... , - 1 ]
270-
271- let sortedToken = categorical ( log ( filteredProbs) * ( 1 / temp) )
272- return sortedIndices. squeezed ( axis: 0 ) [ sortedToken]
256+ return categorical ( logprobs * ( 1 / temp) )
273257 }
274258 }
259+
260+ /// Keep tokens whose cumulative probability, summed in ascending-probability order, exceeds `1 - topP` (nucleus sampling).
261+ /// Matches `apply_top_p` from `mlx_lm/sample_utils.py`.
262+ private func applyTopP( _ logprobs: MLXArray , topP: MLXArray ) -> MLXArray {
263+ let sortedIndices = argSort ( logprobs, axis: - 1 )
264+ let sortedLogprobs = takeAlong ( logprobs, sortedIndices, axis: - 1 )
265+ let sortedProbs = exp ( sortedLogprobs)
266+ let cumulativeProbs = cumsum ( sortedProbs, axis: - 1 )
267+
268+ // Mask low-probability tail in sorted order, scatter back to original vocab order.
269+ let filtered = MLX . where ( cumulativeProbs .> ( 1 - topP) , sortedLogprobs, negInf)
270+ return putAlong ( logprobs, sortedIndices, values: filtered, axis: - 1 )
271+ }
272+
273+ /// Keep tokens with probability >= maxProb * minP.
274+ /// Matches `apply_min_p` from `mlx_lm/sample_utils.py`.
275+ private func applyMinP( _ logprobs: MLXArray , minP: MLXArray ) -> MLXArray {
276+ // threshold in log-space: log(maxProb * minP) = maxLogprob + log(minP)
277+ let maxLogprob = logprobs. max ( axis: - 1 , keepDims: true )
278+ let threshold = maxLogprob + log( minP)
279+ return MLX . where ( logprobs .>= threshold, logprobs, negInf)
280+ }
281+
282+ /// Keep only the top-k highest-probability tokens.
283+ /// Mirrors `apply_top_k` from `mlx_lm/sample_utils.py`.
284+ private func applyTopK( _ logprobs: MLXArray , topK: Int ) -> MLXArray {
285+ let vocabularySize = logprobs. dim ( - 1 )
286+ guard topK < vocabularySize else { return logprobs }
287+ // O(V) partition on negated logprobs so top-k land at [0, topK).
288+ // Indices at [topK, V) are the tokens to mask out.
289+ let maskIndices = argPartition ( - logprobs, kth: topK - 1 , axis: - 1 ) [ 0 ... , topK... ]
290+ return putAlong ( logprobs, maskIndices, values: negInf, axis: - 1 )
291+ }
275292}
276293
277294/// Sampler that uses `temperature` to sample the logits.
@@ -291,151 +308,149 @@ public struct CategoricalSampler: LogitSampler {
291308 }
292309}
293310
294- /// Processor that implements a `repetitionPenalty`
295- public struct RepetitionContext : LogitProcessor {
296- /// tokens in the repetition context sliding window
297- var tokens = [ Int] ( )
311+ /// GPU-resident ring buffer of recent token IDs.
312+ ///
313+ /// Shared by penalty processors to avoid duplicating ring buffer logic.
314+ /// Uses `MLX.where` mask operations for GPU-only updates (no CPU←GPU sync),
315+ /// preserving `asyncEval()` pipelining in `TokenIterator`.
316+ struct TokenRing {
317+ private( set) var buffer : MLXArray
318+ private( set) var count = 0
319+ private var writeIndex = 0
320+ let capacity : Int
321+ private let positions : MLXArray
322+
323+ init ( capacity: Int ) {
324+ precondition ( capacity > 0 )
325+ self . capacity = capacity
326+ self . buffer = MLXArray . zeros ( [ capacity] , type: Int32 . self)
327+ self . positions = MLXArray . arange ( capacity)
328+ }
298329
299- /// current write index into the tokens circular array
300- var index = 0
330+ /// The valid portion of the ring in storage (not chronological) order — all of it once full — or `nil` if empty.
331+ var validTokens : MLXArray ? {
332+ guard count > 0 else { return nil }
333+ return count < capacity ? buffer [ ..< count] : buffer
334+ }
301335
302- /// penalty factor for repeating tokens
303- let repetitionPenalty : Float
336+ /// Bulk-load from a prompt. Keeps the last `capacity` tokens.
337+ mutating func loadPrompt( _ prompt: MLXArray ) {
338+ let n = prompt. dim ( 0 )
339+ let promptTokens = prompt. asType ( . int32)
340+ if n <= capacity {
341+ if n < capacity {
342+ let padding = MLXArray . zeros ( [ capacity - n] , type: Int32 . self)
343+ buffer = concatenated ( [ promptTokens. reshaped ( - 1 ) , padding] )
344+ } else {
345+ buffer = promptTokens. reshaped ( - 1 )
346+ }
347+ count = n
348+ writeIndex = n % capacity
349+ } else {
350+ buffer = promptTokens [ ( - capacity) ... ] . reshaped ( - 1 )
351+ count = capacity
352+ writeIndex = 0
353+ }
354+ }
304355
305- /// number of tokens to consider for repetition penalty
306- let repetitionContextSize : Int
356+ /// Append a single token using GPU-only mask write (no CPU←GPU sync).
357+ mutating func append( _ token: MLXArray ) {
358+ let mask = positions .== Int32 ( writeIndex)
359+ buffer = MLX . where ( mask, token. asType ( . int32) , buffer)
360+ writeIndex = ( writeIndex + 1 ) % capacity
361+ count = min ( count + 1 , capacity)
362+ }
363+ }
364+
365+ /// Processor that implements a `repetitionPenalty`.
366+ public struct RepetitionContext : LogitProcessor {
367+ private var ring : TokenRing
368+ let repetitionPenalty : Float
307369
308370 public init ( repetitionPenalty: Float , repetitionContextSize: Int ) {
309- precondition ( repetitionContextSize > 0 )
310371 self . repetitionPenalty = repetitionPenalty
311- self . repetitionContextSize = repetitionContextSize
372+ self . ring = TokenRing ( capacity : repetitionContextSize)
312373 }
313374
314375 mutating public func prompt( _ prompt: MLXArray ) {
315- if prompt. shape [ 0 ] <= repetitionContextSize {
316- self . tokens = prompt. asArray ( Int . self)
317- } else {
318- self . tokens = prompt [ ( - repetitionContextSize) ... ] . asArray ( Int . self)
319- }
376+ ring. loadPrompt ( prompt)
320377 }
321378
322379 public func process( logits: MLXArray ) -> MLXArray {
323- if tokens. count > 0 {
324- let indices = MLXArray ( tokens. map { UInt32 ( $0) } )
325- var selectedLogits = logits [ 0 ... , indices]
380+ guard let indices = ring. validTokens? . asType ( . uint32) else { return logits }
381+ var selectedLogits = logits [ 0 ... , indices]
326382
327- selectedLogits = MLX . where (
328- selectedLogits .< 0 , selectedLogits * repetitionPenalty,
329- selectedLogits / repetitionPenalty)
330-
331- logits [ 0 ... , indices] = selectedLogits
332- return logits
333- }
383+ selectedLogits = MLX . where (
384+ selectedLogits .< 0 , selectedLogits * repetitionPenalty,
385+ selectedLogits / repetitionPenalty)
334386
387+ logits [ 0 ... , indices] = selectedLogits
335388 return logits
336389 }
337390
338391 mutating public func didSample( token: MLXArray ) {
339- if tokens. count >= repetitionContextSize {
340- tokens [ index] = token. item ( Int . self)
341- index = ( index + 1 ) % repetitionContextSize
342- } else {
343- tokens. append ( token. item ( Int . self) )
344- }
392+ ring. append ( token)
345393 }
346394}
347395
348396/// Processor that applies an additive presence penalty to tokens in a recent context window.
397+ ///
398+ /// The penalty is applied once per unique token: duplicate indices gather the same
399+ /// logit and scatter back the same penalized value, so repeated writes are idempotent.
349400public struct PresencePenaltyContext : LogitProcessor {
350- var tokens = [ Int] ( )
351- var index = 0
352-
401+ private var ring : TokenRing
353402 let presencePenalty : Float
354- let presenceContextSize : Int
355403
356404 public init ( presencePenalty: Float , presenceContextSize: Int ) {
357- precondition ( presenceContextSize > 0 )
358405 self . presencePenalty = presencePenalty
359- self . presenceContextSize = presenceContextSize
406+ self . ring = TokenRing ( capacity : presenceContextSize)
360407 }
361408
362409 mutating public func prompt( _ prompt: MLXArray ) {
363- if prompt. shape [ 0 ] <= presenceContextSize {
364- self . tokens = prompt. asArray ( Int . self)
365- } else {
366- self . tokens = prompt [ ( - presenceContextSize) ... ] . asArray ( Int . self)
367- }
410+ ring. loadPrompt ( prompt)
368411 }
369412
370413 public func process( logits: MLXArray ) -> MLXArray {
371- if tokens. isEmpty {
372- return logits
373- }
374-
375- let uniqueTokens = Array ( Set ( tokens) )
376- let indices = MLXArray ( uniqueTokens. map { UInt32 ( $0) } )
414+ guard let indices = ring. validTokens? . asType ( . uint32) else { return logits }
377415 logits [ 0 ... , indices] = logits [ 0 ... , indices] - presencePenalty
378416 return logits
379417 }
380418
381419 mutating public func didSample( token: MLXArray ) {
382- if tokens. count >= presenceContextSize {
383- tokens [ index] = token. item ( Int . self)
384- index = ( index + 1 ) % presenceContextSize
385- } else {
386- tokens. append ( token. item ( Int . self) )
387- }
420+ ring. append ( token)
388421 }
389422}
390423
391424/// Processor that applies an additive frequency penalty to tokens in a recent context window.
425+ ///
426+ /// Frequency counting is performed on GPU via `scatter_add` to build a histogram
427+ /// of token occurrences, avoiding CPU←GPU synchronization.
392428public struct FrequencyPenaltyContext : LogitProcessor {
393- var tokens = [ Int] ( )
394- var index = 0
395-
429+ private var ring : TokenRing
396430 let frequencyPenalty : Float
397- let frequencyContextSize : Int
398431
399432 public init ( frequencyPenalty: Float , frequencyContextSize: Int ) {
400- precondition ( frequencyContextSize > 0 )
401433 self . frequencyPenalty = frequencyPenalty
402- self . frequencyContextSize = frequencyContextSize
434+ self . ring = TokenRing ( capacity : frequencyContextSize)
403435 }
404436
405437 mutating public func prompt( _ prompt: MLXArray ) {
406- if prompt. shape [ 0 ] <= frequencyContextSize {
407- self . tokens = prompt. asArray ( Int . self)
408- } else {
409- self . tokens = prompt [ ( - frequencyContextSize) ... ] . asArray ( Int . self)
410- }
438+ ring. loadPrompt ( prompt)
411439 }
412440
413441 public func process( logits: MLXArray ) -> MLXArray {
414- if tokens. isEmpty {
415- return logits
416- }
442+ guard let validTokens = ring. validTokens else { return logits }
417443
418- var counts = [ Int : Int ] ( )
419- for token in tokens {
420- counts [ token , default : 0 ] += 1
421- }
444+ let vocabSize = logits . dim ( - 1 )
445+ let ones = MLXArray . ones ( [ validTokens . dim ( 0 ) ] , type : Float32 . self )
446+ let histogram = MLXArray . zeros ( [ vocabSize ] , type : Float32 . self )
447+ . at [ validTokens . asType ( . int32 ) ] . add ( ones )
422448
423- let orderedTokens = Array ( counts. keys)
424- let indices = MLXArray ( orderedTokens. map { UInt32 ( $0) } )
425- let penalties = MLXArray (
426- orderedTokens. map { frequencyPenalty * Float( counts [ $0] ?? 0 ) }
427- )
428- logits [ 0 ... , indices] = logits [ 0 ... , indices] - penalties
429- return logits
449+ return logits - ( histogram * frequencyPenalty) . reshaped ( 1 , - 1 )
430450 }
431451
432452 mutating public func didSample( token: MLXArray ) {
433- if tokens. count >= frequencyContextSize {
434- tokens [ index] = token. item ( Int . self)
435- index = ( index + 1 ) % frequencyContextSize
436- } else {
437- tokens. append ( token. item ( Int . self) )
438- }
453+ ring. append ( token)
439454 }
440455}
441456
0 commit comments