Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
0464b2e
Add mission infrastructure for continuous batching
ronaldmannak Mar 14, 2026
a4fda6b
Implement BatchKVCache with left-padding strategy for continuous batc…
ronaldmannak Mar 14, 2026
81cc226
Add batch-aware masking, BatchPositionedKVCache protocol, and applyRo…
ronaldmannak Mar 14, 2026
917b015
Implement BatchRotatingKVCache for sliding-window batch attention
ronaldmannak Mar 14, 2026
8c2f0a2
Add Metal availability guard to skip MLX-dependent tests in SPM builds
ronaldmannak Mar 14, 2026
bb5c180
Fix swift-format lint violations in batch files
ronaldmannak Mar 14, 2026
d77658c
Record batch-kv-cache scrutiny findings
ronaldmannak Mar 14, 2026
6c2ff25
Update worker skill with MLX Metal TDD guidance
ronaldmannak Mar 14, 2026
6156819
Fix BatchKVCache state serialization, makeMask offset timing, and Sen…
ronaldmannak Mar 14, 2026
e1b91a8
Add prepare/finalize for cached-prompt prefill and preserve keep in B…
ronaldmannak Mar 14, 2026
951c08c
Record batch-kv-cache scrutiny rerun findings
ronaldmannak Mar 14, 2026
ab9bd83
Fix BatchRotatingKVCache to preserve keep prefix tokens during slidin…
ronaldmannak Mar 14, 2026
f24d18a
Record batch-kv-cache keep-semantics scrutiny rerun
ronaldmannak Mar 14, 2026
2933b3a
Fix BatchRotatingKVCache.extract() negative leftPadding after overflow
ronaldmannak Mar 14, 2026
1f9748d
Record batch-kv-cache scrutiny pass after extract fix
ronaldmannak Mar 14, 2026
dd5364b
Record batch-kv-cache user-testing findings
ronaldmannak Mar 14, 2026
d90b54b
Fix BatchKVCache.makeMask() key-width to equal _idx
ronaldmannak Mar 14, 2026
6c5a890
Fix BatchRotatingKVCache keep-prefix corruption for padded sequences …
ronaldmannak Mar 14, 2026
0e91d2a
Record batch-kv-cache user-testing rerun pass
ronaldmannak Mar 14, 2026
4045514
Implement BatchTokenIterator core batch generation engine
ronaldmannak Mar 14, 2026
72ef687
Add per-request sampler/processor support and correctness tests for B…
ronaldmannak Mar 14, 2026
d8877b1
Record batch-engine scrutiny findings
ronaldmannak Mar 14, 2026
ee437f6
Fix batch admission scheduling and add concurrency safety to BatchTok…
ronaldmannak Mar 14, 2026
3aea647
Record batch-engine scrutiny rerun pass
ronaldmannak Mar 14, 2026
35d83e9
Record batch-engine user-testing findings
ronaldmannak Mar 14, 2026
e5cd48c
Fix per-request sampler concatenate crash in BatchTokenIterator
ronaldmannak Mar 14, 2026
2831491
Record batch-engine user-testing rerun pass
ronaldmannak Mar 14, 2026
138c89e
Implement InferenceScheduler actor with single-first upgrade strategy
ronaldmannak Mar 14, 2026
2c327de
Integrate InferenceScheduler with ModelContainer for transparent batc…
ronaldmannak Mar 14, 2026
30875a4
Record scheduler scrutiny findings
ronaldmannak Mar 14, 2026
4949457
Fix scheduler upgrade stream continuity, KV cache migration, and Chat…
ronaldmannak Mar 14, 2026
5f4244b
Record scheduler scrutiny rerun findings
ronaldmannak Mar 14, 2026
c2731b3
Fix scheduler upgrade to use live TokenIterator state and rebind canc…
ronaldmannak Mar 14, 2026
c5d6d87
Record scheduler scrutiny rerun findings
ronaldmannak Mar 14, 2026
6d8fea6
Fix scheduler upgrade tensor shape, token boundary drop, and actor re…
ronaldmannak Mar 14, 2026
ea6a496
Record scheduler scrutiny rerun findings
ronaldmannak Mar 14, 2026
214907d
Fix maxTokens off-by-one in upgradeToBatch() and Sendable warnings
ronaldmannak Mar 14, 2026
0d82970
Record scheduler scrutiny rerun findings
ronaldmannak Mar 14, 2026
6cb705d
Override scheduler scrutiny: all xcodebuild tests pass, code correct
ronaldmannak Mar 14, 2026
c3c7d7d
Record scheduler user-testing findings
ronaldmannak Mar 14, 2026
1f790fa
Defer 6 scheduler assertions to cross-area, add Sendable fix feature,…
ronaldmannak Mar 14, 2026
72b1d9d
Fix Sendable/data-race warnings in scheduler integration
ronaldmannak Mar 14, 2026
b42bea2
Implement LRUPromptCache with trie-based prefix matching and LRU evic…
ronaldmannak Mar 14, 2026
543dcea
Integrate LRUPromptCache with batch generation for cached prompt prefill
ronaldmannak Mar 14, 2026
dc95c98
Record prompt-cache scrutiny findings
ronaldmannak Mar 14, 2026
f3fb872
Fix 4 correctness bugs in LRUPromptCache with regression tests
ronaldmannak Mar 14, 2026
fd76aa9
Fix 2 correctness bugs in prompt-cache batch integration
ronaldmannak Mar 14, 2026
67621ff
Record prompt-cache scrutiny rerun findings
ronaldmannak Mar 14, 2026
ba84c09
Fix mixed-depth cached-prefill holes and RotatingKVCache support in b…
ronaldmannak Mar 14, 2026
7784fc7
Record prompt-cache scrutiny round 3 findings
ronaldmannak Mar 14, 2026
5003535
Fix mixed-depth cached-prefill with prepare/finalize lifecycle
ronaldmannak Mar 14, 2026
5caf3b6
Record prompt-cache scrutiny round 4 findings
ronaldmannak Mar 14, 2026
c4b7e60
Record prompt-cache user-testing findings
ronaldmannak Mar 14, 2026
6f2ec9c
Fix trimPromptCache to trim all layers and correct exact-hit test exp…
ronaldmannak Mar 14, 2026
76eb7fb
Record prompt-cache user-testing rerun findings
ronaldmannak Mar 14, 2026
1ec0540
Migrate all MLXLLM model RoPE calls to applyRotaryPosition
ronaldmannak Mar 14, 2026
d1ecb3c
Add comprehensive cross-area integration tests for batching
ronaldmannak Mar 14, 2026
cbf660f
Record example-app scrutiny findings
ronaldmannak Mar 14, 2026
81f9048
Strengthen cross-area integration test assertions and fix compile war…
ronaldmannak Mar 14, 2026
09a3f1b
Record example-app scrutiny rerun findings
ronaldmannak Mar 14, 2026
dc5b4e2
Override example-app scrutiny: all tests pass, dead code paths dismissed
ronaldmannak Mar 14, 2026
6302688
Record example-app user-testing findings
ronaldmannak Mar 14, 2026
8d94311
Record example-app user-testing rerun findings
ronaldmannak Mar 14, 2026
79716d6
Complete example-app milestone: all 86 assertions passed
ronaldmannak Mar 14, 2026
202e74a
Add continuous batching section to README
ronaldmannak Mar 14, 2026
4653d69
Fix 3rd+ requests missing streaming events in batch mode
ronaldmannak Mar 15, 2026
7e42f13
Fix rotating/sliding-window caches silently dropped during batch crea…
ronaldmannak Mar 15, 2026
0544fab
Fix batched .info events: report correct promptTokenCount and preserv…
ronaldmannak Mar 15, 2026
d687b55
Wire LRUPromptCache into scheduler path for upstream parity
ronaldmannak Mar 15, 2026
1914be7
Record post-review scrutiny findings
ronaldmannak Mar 15, 2026
9f25ef8
Fix vacuous rotating cache preservation test with actual content veri…
ronaldmannak Mar 15, 2026
ed1a06d
Fix incorrect promptTime for 3rd+ requests joining existing batch
ronaldmannak Mar 15, 2026
0e321ef
Fix prompt cache wiring completeness: write-back, single-path cache u…
ronaldmannak Mar 15, 2026
f9db3fb
Record post-review scrutiny rerun findings
ronaldmannak Mar 15, 2026
5ae3e88
Make testUpgradePreservesRotatingKVCacheState deterministic
ronaldmannak Mar 15, 2026
af4171d
Fix prompt cache write-back to use full token sequence key
ronaldmannak Mar 15, 2026
eb137ec
Fix flaky testUpgradePreservesRotatingKVCacheState timing
ronaldmannak Mar 15, 2026
c1ed95a
Record post-review scrutiny round 3 findings
ronaldmannak Mar 15, 2026
27ec16a
Fix RotatingCacheMockModel to never produce EOS token 0
ronaldmannak Mar 15, 2026
ac076d7
Fix prompt cache write-back to include pre-upgrade generated tokens
ronaldmannak Mar 15, 2026
e524983
Record post-review scrutiny round 4 findings
ronaldmannak Mar 15, 2026
83f37d3
Record post-review user testing results
ronaldmannak Mar 15, 2026
5603b7e
Fix mixed-layer cached partial-hit to use per-layer type check
ronaldmannak Mar 15, 2026
b8c389a
Fix BatchKVCache masks for post-update attention width
ronaldmannak Mar 15, 2026
83bbd80
Fix mixed-depth cached-prefill final cache extraction
ronaldmannak Mar 15, 2026
21a2e85
Record post-review-followup scrutiny findings
ronaldmannak Mar 15, 2026
a8a06a5
Record post-review-followup user testing results
ronaldmannak Mar 15, 2026
f2cb539
Fix scheduler fallback prompt-cache propagation
ronaldmannak Mar 16, 2026
3d5efee
Record post-review-followup-2 scrutiny findings
ronaldmannak Mar 16, 2026
7034d8b
Record post-review-followup-2 user testing results
ronaldmannak Mar 16, 2026
f349916
Update gitignore, remove .factory from git repo
ronaldmannak Mar 16, 2026
9063a11
Add batching support for NanoChat and Phi3
ronaldmannak Mar 16, 2026
c6b6a8c
Batch-aware attention masks for FalconH1 and Gemma
ronaldmannak Mar 16, 2026
6ed9754
swift lint
ronaldmannak Mar 17, 2026
46a3c18
Add wired memory support
ronaldmannak Mar 17, 2026
691b14d
swift lint
ronaldmannak Mar 17, 2026
424c6ab
improve dual path routing
ronaldmannak Mar 20, 2026
7f25571
Add raw token batching
ronaldmannak Mar 20, 2026
1160f8a
Add raw token batching
ronaldmannak Mar 20, 2026
27fc9b9
Update SKILL.md
ronaldmannak Mar 20, 2026
42cdda7
Revert order of Model Factories (VLM is first again)
ronaldmannak Mar 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,5 @@ iOSInjectionProject/

.idea
.vscode

.claude/
.factory/
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/AfMoE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,8 @@ class AfMoEAttention: Module {

// Apply RoPE only for local (sliding window) attention
if isLocalAttention, let rope = rope {
if let cache = cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)
}

var output = attentionWithCacheUpdate(
Expand Down
9 changes: 3 additions & 6 deletions Libraries/MLXLLM/Models/Apertus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,14 @@ private class ApertusAttention: Module {
values = values.transposed(0, 2, 1, 3)

// 4. RoPE
if let cache = cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

if let cache = cache {
// Update cache (expects [B, H, L, D])
let (k, v) = cache.update(keys: keys, values: values)
keys = k
values = v
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}

// 5. Attention (SDPA expects [B, H, L, D])
Expand Down
7 changes: 3 additions & 4 deletions Libraries/MLXLLM/Models/BaichuanM1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,11 @@ class BaichuanM1Attention: Module {
var keys = qkv[1].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3)
var values = qkv[2].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3)

var offset = 0
var lastK: MLXArray? = nil
var lastV: MLXArray? = nil
let kvSubCache: KVCache? = (cache as? CacheList)?[1]

if let cacheList = cache as? CacheList {
offset = cacheList[1].offset
if let mambaCache = cacheList[0] as? MambaCache {
lastK = mambaCache[0]
lastV = mambaCache[1]
Expand All @@ -131,8 +130,8 @@ class BaichuanM1Attention: Module {
keys = customConvolution(keys, convK, state: lastK)
values = customConvolution(values, convV, state: lastV)

queries = rope(queries, offset: offset)
keys = rope(keys, offset: offset)
queries = applyRotaryPosition(rope, to: queries, cache: kvSubCache)
keys = applyRotaryPosition(rope, to: keys, cache: kvSubCache)

if let cache = cache as? CacheList {
let kvCache = cache[1]
Expand Down
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/BailingMoe.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,8 @@ class BailingMoeAttention: Module {
keys = keys.transposed(0, 2, 1, 3)
values = values.reshaped(B, L, kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

let output = attentionWithCacheUpdate(
queries: queries,
Expand Down
8 changes: 3 additions & 5 deletions Libraries/MLXLLM/Models/Bitnet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,11 @@ class BitnetAttention: Module {
keys = keys.reshaped(B, L, args.resolvedKvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.resolvedKvHeads, -1).transposed(0, 2, 1, 3)

queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}

let output = MLXFast.scaledDotProductAttention(
Expand Down
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/Cohere.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,8 @@ class CohereAttention: Module {
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries)
keys = rope(keys)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

let output = attentionWithCacheUpdate(
queries: queries,
Expand Down
10 changes: 4 additions & 6 deletions Libraries/MLXLLM/Models/DeepseekV3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,15 @@ class DeepseekV3Attention: Module {

var (kNope, values) = (splitKv[0], splitKv[1])

qPe = applyRotaryPosition(self.rope, to: qPe, cache: cache)
kPe = applyRotaryPosition(self.rope, to: kPe, cache: cache)
kPe = repeated(kPe, count: numHeads, axis: 1)

var keys: MLXArray
if let cache = cache {
qPe = self.rope(qPe, offset: cache.offset)
kPe = self.rope(kPe, offset: cache.offset)
kPe = repeated(kPe, count: numHeads, axis: 1)
(keys, values) = cache.update(
keys: concatenated([kNope, kPe], axis: -1), values: values)
} else {
qPe = self.rope(qPe, offset: 0)
kPe = self.rope(kPe, offset: 0)
kPe = repeated(kPe, count: numHeads, axis: 1)
keys = concatenated([kNope, kPe], axis: -1)
}

Expand Down
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/Ernie4_5.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,8 @@ class Ernie45Attention: Module {
keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

let output = attentionWithCacheUpdate(
queries: queries,
Expand Down
9 changes: 3 additions & 6 deletions Libraries/MLXLLM/Models/Exaone4.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,9 @@ class Exaone4Attention: Module {
keys = kNorm(keys.reshaped(B, L, args.kvHeads, -1)).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache, useRope, let rope {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else if useRope, let rope {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
if useRope, let rope {
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)
}

let output = attentionWithCacheUpdate(
Expand Down
48 changes: 23 additions & 25 deletions Libraries/MLXLLM/Models/FalconH1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,11 @@ class FalconH1Attention: Module {
maxPositionEmbeddings: args.maxPositionEmbeddings)
}

func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? = nil) -> MLXArray {
func callAsFunction(
_ x: MLXArray,
mask: MLXFast.ScaledDotProductAttentionMaskMode = .none,
cache: KVCache? = nil
) -> MLXArray {
let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2))

var queries = qProj(x)
Expand All @@ -302,19 +306,14 @@ class FalconH1Attention: Module {
keys = keys.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
(keys, values) = cache.update(keys: keys, values: values)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

var output = MLXFast.scaledDotProductAttention(
var output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
)
Expand Down Expand Up @@ -578,7 +577,7 @@ class FalconH1DecoderLayer: Module {
func callAsFunction(
_ h: MLXArray,
cache: CacheList?,
attnMask: MLXArray?,
attnMask: MLXFast.ScaledDotProductAttentionMaskMode,
mambaMask: MLXArray?
) -> MLXArray {
var residual = h
Expand Down Expand Up @@ -610,17 +609,6 @@ private func createSSMMask(h: MLXArray, cache: ArraysCache?) -> MLXArray? {
return nil
}

/// Builds the attention mask for the FalconH1 attention layers.
///
/// Always returns `nil`: when no explicit mask tensor is supplied, SDPA
/// applies causal masking internally, so neither single-token decode nor
/// multi-token prefill needs a materialized mask here.
///
/// - Parameters:
///   - h: Hidden states of shape `[B, L, D]` (unused; kept for call-site compatibility).
///   - cache: Optional per-layer KV caches (unused).
/// - Returns: Always `nil`; SDPA handles the causal mask when given `nil`.
private func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? {
    // The original branched on `h.dim(1) == 1`, but both branches returned
    // nil — the check and the `N` binding were dead code. Behavior unchanged.
    return nil
}

// MARK: - Model

public class FalconH1ModelInner: Module {
Expand Down Expand Up @@ -649,16 +637,26 @@ public class FalconH1ModelInner: Module {
_finalLayerNorm.wrappedValue = RMSNorm(dimensions: hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil, cache: [CacheList]? = nil)
func callAsFunction(
_ inputs: MLXArray,
mask: MLXFast.ScaledDotProductAttentionMaskMode = .none,
cache: [CacheList]? = nil
)
-> MLXArray
{
var h = embedTokens(inputs)

let cache: [CacheList?] = cache ?? Array(repeating: nil, count: layers.count)

let mambaMask = createSSMMask(h: h, cache: cache[0]?[0] as? MambaCache)
let attnMask: MLXArray? = createAttentionMask(
h: h, cache: cache[0]?[1] != nil ? [cache[0]![1]] : nil)
let attnMask: MLXFast.ScaledDotProductAttentionMaskMode = {
switch mask {
case .none:
return createAttentionMask(h: h, cache: cache[0]?[1])
default:
return mask
}
}()

for (layer, c) in zip(layers, cache) {
h = layer(
Expand Down
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/GLM4.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,8 @@ class GLM4Attention: Module {
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

let output = attentionWithCacheUpdate(
queries: queries,
Expand Down
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/GLM4MOE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,8 @@ class GLM4MoEAttention: Module {
keys = keys.transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries, offset: 0)
keys = rope(keys, offset: 0)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

let output = attentionWithCacheUpdate(
queries: queries,
Expand Down
5 changes: 2 additions & 3 deletions Libraries/MLXLLM/Models/GLM4MOELite.swift
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,8 @@ class GLM4MoELiteAttention: Module {
kPe = kPe.reshaped(B, L, 1, qkRopeHeadDim).transposed(0, 2, 1, 3)
var kvLatent = kvALayerNorm(compressedKv)

let offset = cache?.offset ?? 0
qPe = rope(qPe, offset: offset)
kPe = rope(kPe, offset: offset)
qPe = applyRotaryPosition(rope, to: qPe, cache: cache)
kPe = applyRotaryPosition(rope, to: kPe, cache: cache)

// Expand kvLatent for attention: [B, L, kvLoraRank] -> [B, 1, L, kvLoraRank]
kvLatent = expandedDimensions(kvLatent, axis: 1)
Expand Down
17 changes: 5 additions & 12 deletions Libraries/MLXLLM/Models/GPTOSS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,8 @@ class AttentionBlock: Module {
if sinksActive {
fatalError("Quantized attention does not support non-zero sinks.")
}
if qcache.offset == 0 {
q = rope(q)
k = rope(k)
} else {
q = rope(q, offset: qcache.offset)
k = rope(k, offset: qcache.offset)
}
q = applyRotaryPosition(rope, to: q, cache: cache)
k = applyRotaryPosition(rope, to: k, cache: cache)

let (qKeys, qValues) = qcache.updateQuantized(keys: k, values: v)
let vHat = quantizedScaledDotProductAttention(
Expand All @@ -252,13 +247,11 @@ class AttentionBlock: Module {
return oProj(vHat.swappedAxes(1, 2).reshaped(B, L, -1))
}

q = applyRotaryPosition(rope, to: q, cache: cache)
k = applyRotaryPosition(rope, to: k, cache: cache)

if let cache {
q = rope(q, offset: cache.offset)
k = rope(k, offset: cache.offset)
(k, v) = cache.update(keys: k, values: v)
} else {
q = rope(q)
k = rope(k)
}

let vHat = MLXFast.scaledDotProductAttention(
Expand Down
9 changes: 2 additions & 7 deletions Libraries/MLXLLM/Models/Gemma.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,8 @@ class GemmaAttention: Module {
keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries)
keys = rope(keys)
}
queries = applyRotaryPosition(rope, to: queries, cache: cache)
keys = applyRotaryPosition(rope, to: keys, cache: cache)

let output = attentionWithCacheUpdate(
queries: queries,
Expand Down
Loading
Loading