Skip to content

Commit 0cd600d

Browse files
authored
quantized cache path (#379)
1 parent 24219cd commit 0cd600d

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

Libraries/MLXLLM/Models/GPTOSS.swift

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -283,6 +283,34 @@ private class AttentionBlock: Module {
283283
var k = kProj(x).reshaped(B, L, -1, D).swappedAxes(1, 2)
284284
var v = vProj(x).reshaped(B, L, -1, D).swappedAxes(1, 2)
285285

286+
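// Quantized cache path: if the cache is a QuantizedKVCacheProtocol, apply RoPE, update the quantized cache, run quantized attention, and return early.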
287+
if let qcache = cache as? QuantizedKVCacheProtocol {
288+
if qcache.offset == 0 {
289+
q = rope(q)
290+
k = rope(k)
291+
292+
let zeros = MLXArray.zeros([B, Hk, 1, D]).asType(k.dtype)
293+
k = concatenated([zeros, k], axis: 2)
294+
v = concatenated([zeros, v], axis: 2)
295+
} else {
296+
q = rope(q, offset: qcache.offset - 1)
297+
k = rope(k, offset: qcache.offset - 1)
298+
}
299+
300+
let (qKeys, qValues) = qcache.updateQuantized(keys: k, values: v)
301+
let vHat = quantizedScaledDotProductAttention(
302+
queries: q,
303+
quantizedKeys: qKeys,
304+
quantizedValues: qValues,
305+
scale: smScale,
306+
mask: .array(mask),
307+
groupSize: qcache.groupSize,
308+
bits: qcache.bits
309+
)
310+
311+
return oProj(vHat.swappedAxes(1, 2).reshaped(B, L, -1))
312+
}
313+
286314
if cache == nil || cache?.offset == 0 {
287315
q = rope(q)
288316
k = rope(k)

0 commit comments

Comments
 (0)