Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Libraries/MLXEmbedders/Models/Qwen3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ private class Attention: Module {
if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}

// 4. Efficient Scaled Dot-Product Attention
let output = MLXFast.scaledDotProductAttention(
let output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
)
Expand Down
8 changes: 2 additions & 6 deletions Libraries/MLXLLM/Models/Apertus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -227,21 +227,17 @@ private class ApertusAttention: Module {
if let cache = cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)

// Update cache (expects [B, H, L, D])
let (k, v) = cache.update(keys: keys, values: values)
keys = k
values = v
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}

// 5. Attention (SDPA expects [B, H, L, D])
let output = MLXFast.scaledDotProductAttention(
let output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
)
Expand Down
14 changes: 7 additions & 7 deletions Libraries/MLXLLM/Models/BaichuanM1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,20 @@ class BaichuanM1Attention: Module {
keys = rope(keys, offset: offset)

if let cache = cache as? CacheList {
let kvCache = cache[1]
let (cachedKeys, cachedValues) = kvCache.update(keys: keys, values: values)
keys = cachedKeys
values = cachedValues

if L > 0 {
let convCache = cache[0] as! MambaCache
convCache[0] = kInit[0..., 0..., (L - 1)..., 0...]
convCache[1] = vInit[0..., 0..., (L - 1)..., 0...]
}
}

let out = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
let out = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: (cache as? CacheList)?[1],
scale: scale,
mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)
Expand Down
4 changes: 2 additions & 2 deletions Libraries/MLXLLM/Models/Bitnet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,16 @@ class BitnetAttention: Module {
if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}

let output = MLXFast.scaledDotProductAttention(
let output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
)
Expand Down
8 changes: 2 additions & 6 deletions Libraries/MLXLLM/Models/DeepseekV3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,15 @@ class DeepseekV3Attention: Module {

var (kNope, values) = (splitKv[0], splitKv[1])

var keys: MLXArray
if let cache = cache {
qPe = self.rope(qPe, offset: cache.offset)
kPe = self.rope(kPe, offset: cache.offset)
kPe = repeated(kPe, count: numHeads, axis: 1)
(keys, values) = cache.update(
keys: concatenated([kNope, kPe], axis: -1), values: values)
} else {
qPe = self.rope(qPe, offset: 0)
kPe = self.rope(kPe, offset: 0)
kPe = repeated(kPe, count: numHeads, axis: 1)
keys = concatenated([kNope, kPe], axis: -1)
}
kPe = repeated(kPe, count: numHeads, axis: 1)
let keys = concatenated([kNope, kPe], axis: -1)

let queries = concatenated([qNope, qPe], axis: -1)

Expand Down
8 changes: 5 additions & 3 deletions Libraries/MLXLLM/Models/FalconH1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -305,18 +305,20 @@ class FalconH1Attention: Module {
if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}

var output = MLXFast.scaledDotProductAttention(
let attentionMask =
mask.map { MLXFast.ScaledDotProductAttentionMaskMode.array($0) } ?? .none
var output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
mask: attentionMask
)

output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1)
Expand Down
8 changes: 2 additions & 6 deletions Libraries/MLXLLM/Models/GLM4MOELite.swift
Original file line number Diff line number Diff line change
Expand Up @@ -268,19 +268,15 @@ class GLM4MoELiteAttention: Module {
var keys = concatenated([kvLatent, kPe], axis: -1)
var values = kvLatent // Values are the compressed KV latent

// Update cache with compressed representation
if let cache {
(keys, values) = cache.update(keys: keys, values: values)
}

// Create queries
let queries = concatenated([qNope, qPe], axis: -1)

// Compute attention
var output = MLXFast.scaledDotProductAttention(
var output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
)
Expand Down
39 changes: 7 additions & 32 deletions Libraries/MLXLLM/Models/GPTOSS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,48 +224,23 @@ class AttentionBlock: Module {
return active
}()

// Quantized cache path
if let qcache = cache as? QuantizedKVCacheProtocol {
if sinksActive {
fatalError("Quantized attention does not support non-zero sinks.")
}
if qcache.offset == 0 {
q = rope(q)
k = rope(k)
} else {
q = rope(q, offset: qcache.offset)
k = rope(k, offset: qcache.offset)
}

let (qKeys, qValues) = qcache.updateQuantized(keys: k, values: v)
let vHat = quantizedScaledDotProductAttention(
queries: q,
quantizedKeys: qKeys,
quantizedValues: qValues,
scale: smScale,
mask: mask,
groupSize: qcache.groupSize,
bits: qcache.bits,
mode: qcache.mode
)

return oProj(vHat.swappedAxes(1, 2).reshaped(B, L, -1))
}

if let cache {
q = rope(q, offset: cache.offset)
k = rope(k, offset: cache.offset)
(k, v) = cache.update(keys: k, values: v)
} else {
q = rope(q)
k = rope(k)
}

let vHat = MLXFast.scaledDotProductAttention(
queries: q, keys: k, values: v,
let vHat = attentionWithCacheUpdate(
queries: q,
keys: k,
values: v,
cache: cache,
scale: smScale,
mask: mask,
sinks: sinksActive ? sinks : nil)
sinks: sinksActive ? sinks : nil
)

return oProj(vHat.swappedAxes(1, 2).reshaped(B, L, -1))
}
Expand Down
2 changes: 1 addition & 1 deletion Libraries/MLXLLM/Models/Gemma2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Gemma2Attention: Module {
if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
(keys, values) = updateCacheAndReturnArrays(keys: keys, values: values, cache: cache)
} else {
queries = rope(queries)
keys = rope(keys)
Expand Down
19 changes: 13 additions & 6 deletions Libraries/MLXLLM/Models/Gemma3nText.swift
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,9 @@ class Gemma3nAttention: Module {
var values: MLXArray

if isKvSharedLayer && cache != nil {
let state = cache!.state
if state.count >= 2 {
keys = state[0]
values = state[1]
if let state = dequantizedKVState(cache: cache!) {
keys = state.0
values = state.1
} else {
keys = kProj(x).reshaped(B, L, -1, headDim)
keys = kNorm(keys)
Expand All @@ -289,7 +288,11 @@ class Gemma3nAttention: Module {
values = values.transposed(0, 2, 1, 3)

if let cache = cache {
(keys, values) = cache.update(keys: keys, values: values)
(keys, values) = updateCacheAndReturnArrays(
keys: keys,
values: values,
cache: cache
)
}
}
} else {
Expand All @@ -303,7 +306,11 @@ class Gemma3nAttention: Module {
values = values.transposed(0, 2, 1, 3)

if let cache = cache {
(keys, values) = cache.update(keys: keys, values: values)
(keys, values) = updateCacheAndReturnArrays(
keys: keys,
values: values,
cache: cache
)
}
}

Expand Down
45 changes: 9 additions & 36 deletions Libraries/MLXLLM/Models/MiMoV2Flash.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,15 @@ private func attentionWithCacheUpdateAndSinks(
mask: MLXFast.ScaledDotProductAttentionMaskMode = .none,
sinks: MLXArray? = nil
) -> MLXArray {
guard let cache else {
return MLXFast.scaledDotProductAttention(
queries: queries,
keys: keys,
values: values,
scale: scale,
mask: mask,
sinks: sinks
)
}

if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
precondition(sinks == nil, "Quantized SDPA does not support attention sinks.")
let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized(
keys: keys, values: values)
return quantizedScaledDotProductAttention(
queries: queries,
quantizedKeys: quantizedKeys,
quantizedValues: quantizedValues,
scale: scale,
mask: mask,
groupSize: quantizedKVCache.groupSize,
bits: quantizedKVCache.bits,
mode: quantizedKVCache.mode
)
} else {
let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
return MLXFast.scaledDotProductAttention(
queries: queries,
keys: cachedKeys,
values: cachedValues,
scale: scale,
mask: mask,
sinks: sinks
)
}
attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask,
sinks: sinks
)
}

private func groupExpertSelect(
Expand Down
Loading
Loading