@@ -283,6 +283,34 @@ private class AttentionBlock: Module {
283283 var k = kProj ( x) . reshaped ( B, L, - 1 , D) . swappedAxes ( 1 , 2 )
284284 var v = vProj ( x) . reshaped ( B, L, - 1 , D) . swappedAxes ( 1 , 2 )
285285
286+ // Quantized cache path
287+ if let qcache = cache as? QuantizedKVCacheProtocol {
288+ if qcache. offset == 0 {
289+ q = rope ( q)
290+ k = rope ( k)
291+
292+ let zeros = MLXArray . zeros ( [ B, Hk, 1 , D] ) . asType ( k. dtype)
293+ k = concatenated ( [ zeros, k] , axis: 2 )
294+ v = concatenated ( [ zeros, v] , axis: 2 )
295+ } else {
296+ q = rope ( q, offset: qcache. offset - 1 )
297+ k = rope ( k, offset: qcache. offset - 1 )
298+ }
299+
300+ let ( qKeys, qValues) = qcache. updateQuantized ( keys: k, values: v)
301+ let vHat = quantizedScaledDotProductAttention (
302+ queries: q,
303+ quantizedKeys: qKeys,
304+ quantizedValues: qValues,
305+ scale: smScale,
306+ mask: . array( mask) ,
307+ groupSize: qcache. groupSize,
308+ bits: qcache. bits
309+ )
310+
311+ return oProj ( vHat. swappedAxes ( 1 , 2 ) . reshaped ( B, L, - 1 ) )
312+ }
313+
286314 if cache == nil || cache? . offset == 0 {
287315 q = rope ( q)
288316 k = rope ( k)
0 commit comments