Skip to content

Commit 807885c

Browse files
authored
Add copy() to KVCache protocol and all implementations (#158)
* Add copy() to KVCache protocol and all implementations Add an independent deep-copy method to enable reusing a prefix cache across multiple ChatSession instances without reloading from disk. - Add copy() requirement to KVCache protocol - Implement on KVCacheSimple, RotatingKVCache, QuantizedKVCache, ChunkedKVCache, ArraysCache, MambaCache, and CacheList - Guard against empty state in all copy() methods to avoid fatalError from state setters that reject empty arrays - Preserve leftPadding in ArraysCache/MambaCache copies - Add CacheList array-based initializer to support copy() - Change ArraysCache.leftPadding from private to internal for subclass access in MambaCache.copy() Tests: - testCacheCopyIsIndependent: parameterized across 6 cache types, verifies copy has same state and mutation of copy leaves original unchanged - testCacheCopyOnEmptyCache: verifies copy of unpopulated cache does not crash - testCacheListCopyIsIndependent: verifies CacheList with heterogeneous sub-caches copies independently * Bump mlx-swift dependency to 0.31.1 Picks up the fix for array[.ellipsis] returning self instead of a copy (ml-explore/mlx-swift#367), plus mlx 0.31.1 C++ updates.
1 parent f7a235d commit 807885c

File tree

3 files changed

+226
-10
lines changed

3 files changed

+226
-10
lines changed

Libraries/MLXLMCommon/KVCache.swift

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ public protocol KVCache: Evaluatable {
7171
func makeMask(
7272
n: Int, windowSize: Int?, returnArray: Bool
7373
) -> MLXFast.ScaledDotProductAttentionMaskMode
74+
75+
/// Create an independent deep copy of this cache.
76+
func copy() -> any KVCache
7477
}
7578

7679
/// Protocol for caches that support efficient quantized operations
@@ -149,6 +152,10 @@ open class BaseKVCache: KVCache {
149152
@discardableResult
150153
open func trim(_ n: Int) -> Int { 0 }
151154

155+
open func copy() -> any KVCache {
156+
fatalError("copy() must be implemented by subclass")
157+
}
158+
152159
/// Default implementation for caches without special mask requirements
153160
open func makeMask(
154161
n: Int, windowSize: Int?, returnArray: Bool
@@ -419,6 +426,16 @@ public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible {
419426
return quantizedCache
420427
}
421428

429+
public override func copy() -> any KVCache {
430+
let new = KVCacheSimple()
431+
new.step = self.step
432+
let s = self.state
433+
if !s.isEmpty {
434+
new.state = s.map { $0[.ellipsis] }
435+
}
436+
return new
437+
}
438+
422439
public var debugDescription: String {
423440
"\(String(describing: Self.self)) \(Unmanaged.passUnretained(self).toOpaque()), offset: \(offset), step: \(step), keys: \(keys?.shape.description ?? "-"), values: \(values?.shape.description ?? "-")"
424441
}
@@ -680,6 +697,16 @@ public class RotatingKVCache: BaseKVCache, CustomDebugStringConvertible {
680697
"\(String(describing: Self.self)) offset: \(offset), maxSize: \(maxCacheSize.description), keep: \(keep), idx: \(idx)"
681698
}
682699

700+
public override func copy() -> any KVCache {
701+
let new = RotatingKVCache(maxSize: maxCacheSize, keep: keep, step: step)
702+
let s = self.state
703+
if !s.isEmpty {
704+
new.state = s.map { $0[.ellipsis] }
705+
}
706+
new.metaState = self.metaState
707+
return new
708+
}
709+
683710
/// Convert to quantized cache
684711
/// Note: This is complex due to the rotating nature and temporal ordering
685712
public func toQuantized(groupSize: Int = 64, bits: Int = 4) -> QuantizedKVCache {
@@ -925,6 +952,16 @@ public class QuantizedKVCache: BaseKVCache, QuantizedKVCacheProtocol {
925952
return trimmed
926953
}
927954

955+
public override func copy() -> any KVCache {
956+
let new = QuantizedKVCache(groupSize: groupSize, bits: bits, mode: mode)
957+
let s = self.state
958+
if !s.isEmpty {
959+
new.state = s.map { $0[.ellipsis] }
960+
}
961+
new.metaState = self.metaState
962+
return new
963+
}
964+
928965
/// Convert to unquantized cache
929966
public func toUnquantized() -> KVCacheSimple {
930967
let simpleCache = KVCacheSimple()
@@ -1014,6 +1051,17 @@ public class ChunkedKVCache: KVCacheSimple {
10141051
return trimmed
10151052
}
10161053

1054+
public override func copy() -> any KVCache {
1055+
let new = ChunkedKVCache(chunkSize: chunkSize)
1056+
new.step = self.step
1057+
let s = self.state
1058+
if !s.isEmpty {
1059+
new.state = s.map { $0[.ellipsis] }
1060+
}
1061+
new.metaState = self.metaState
1062+
return new
1063+
}
1064+
10171065
public override var metaState: [String] {
10181066
get {
10191067
let chunkSizeStr = chunkSize?.description ?? "None"
@@ -1036,7 +1084,7 @@ public class ChunkedKVCache: KVCacheSimple {
10361084
/// Base cache for array-based state storage
10371085
public class ArraysCache: BaseKVCache {
10381086
private var cache: [MLXArray?]
1039-
private var leftPadding: MLXArray?
1087+
internal var leftPadding: MLXArray?
10401088

10411089
public init(size: Int, leftPadding: [Int]? = nil) {
10421090
self.cache = Array(repeating: nil, count: size)
@@ -1062,6 +1110,17 @@ public class ArraysCache: BaseKVCache {
10621110
}
10631111
}
10641112

1113+
public override func copy() -> any KVCache {
1114+
let new = ArraysCache(size: cache.count)
1115+
let s = self.state
1116+
if !s.isEmpty {
1117+
new.state = s.map { $0[.ellipsis] }
1118+
}
1119+
new.offset = self.offset
1120+
new.leftPadding = self.leftPadding
1121+
return new
1122+
}
1123+
10651124
/// In-place filter to keep just the given indices in the cache
10661125
public func filter(batchIndices: MLXArray) {
10671126
cache = cache.map { c in
@@ -1096,6 +1155,17 @@ public class MambaCache: ArraysCache {
10961155
public init(leftPadding: [Int]? = nil) {
10971156
super.init(size: 2, leftPadding: leftPadding)
10981157
}
1158+
1159+
public override func copy() -> any KVCache {
1160+
let new = MambaCache()
1161+
let s = self.state
1162+
if !s.isEmpty {
1163+
new.state = s.map { $0[.ellipsis] }
1164+
}
1165+
new.offset = self.offset
1166+
new.leftPadding = self.leftPadding
1167+
return new
1168+
}
10991169
}
11001170

11011171
/// Composite cache that manages multiple sub-caches
@@ -1107,6 +1177,11 @@ public class CacheList: BaseKVCache {
11071177
super.init()
11081178
}
11091179

1180+
public init(_ caches: [any KVCache]) {
1181+
self.caches = caches
1182+
super.init()
1183+
}
1184+
11101185
public override func innerState() -> [MLXArray] {
11111186
caches.flatMap { $0.innerState() }
11121187
}
@@ -1132,6 +1207,12 @@ public class CacheList: BaseKVCache {
11321207
}
11331208
}
11341209

1210+
public override func copy() -> any KVCache {
1211+
let copiedCaches = caches.map { $0.copy() }
1212+
let new = CacheList(copiedCaches)
1213+
return new
1214+
}
1215+
11351216
public override var isTrimmable: Bool {
11361217
caches.allSatisfy { $0.isTrimmable }
11371218
}

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ let package = Package(
2626
targets: ["MLXEmbedders"]),
2727
],
2828
dependencies: [
29-
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.6")),
29+
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1")),
3030
.package(
3131
url: "https://github.com/huggingface/swift-transformers",
3232
.upToNextMinor(from: "1.2.0")

Tests/MLXLMTests/KVCacheTests.swift

Lines changed: 143 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@ import MLX
33
import MLXLMCommon
44
import Testing
55

6+
private let cacheCreators: [() -> any KVCache] = [
7+
{ KVCacheSimple() },
8+
{ RotatingKVCache(maxSize: 32) },
9+
{ QuantizedKVCache() },
10+
{ ChunkedKVCache(chunkSize: 16) },
11+
{ ArraysCache(size: 2) },
12+
{ MambaCache() },
13+
]
14+
615
@Test(
716
.serialized,
8-
arguments: [
9-
({ KVCacheSimple() }),
10-
({ RotatingKVCache(maxSize: 32) }),
11-
({ QuantizedKVCache() }),
12-
({ ChunkedKVCache(chunkSize: 16) }),
13-
({ ArraysCache(size: 2) }),
14-
({ MambaCache() }),
15-
])
17+
arguments: cacheCreators)
1618
func testCacheSerialization(creator: (() -> any KVCache)) async throws {
1719
let cache = (0 ..< 10).map { _ in creator() }
1820
let keys = MLXArray.ones([1, 8, 32, 64], dtype: .bfloat16)
@@ -43,3 +45,136 @@ func testCacheSerialization(creator: (() -> any KVCache)) async throws {
4345
#expect(lhs.state.count == rhs.state.count)
4446
}
4547
}
48+
49+
/// Verify that copy() produces an independent cache: same type, same state,
50+
/// but mutating the copy does not affect the original.
51+
@Test(
52+
.serialized,
53+
arguments: cacheCreators)
54+
func testCacheCopyIsIndependent(creator: (() -> any KVCache)) async throws {
55+
let original = creator()
56+
57+
let keys = MLXArray.ones([1, 8, 4, 64], dtype: .bfloat16)
58+
let values = MLXArray.ones([1, 8, 4, 64], dtype: .bfloat16)
59+
60+
// populate the original
61+
switch original {
62+
case let arrays as ArraysCache:
63+
arrays[0] = keys
64+
arrays[1] = values
65+
case let quantized as QuantizedKVCache:
66+
_ = quantized.updateQuantized(keys: keys, values: values)
67+
default:
68+
_ = original.update(keys: keys, values: values)
69+
}
70+
71+
let originalOffset = original.offset
72+
let originalState = original.state
73+
eval(originalState)
74+
let originalMeta = original.metaState
75+
76+
// copy
77+
let copied = original.copy()
78+
79+
// same type
80+
#expect(type(of: original) == type(of: copied))
81+
82+
// same offset and metadata
83+
#expect(copied.offset == originalOffset)
84+
#expect(copied.metaState == originalMeta)
85+
86+
// same state values
87+
let copiedState = copied.state
88+
eval(copiedState)
89+
#expect(copiedState.count == originalState.count)
90+
for (origArr, copyArr) in zip(originalState, copiedState) {
91+
#expect(origArr.shape == copyArr.shape)
92+
#expect(allClose(origArr, copyArr).item(Bool.self))
93+
}
94+
95+
// mutate the copy — push more tokens through it
96+
let moreKeys = MLXArray.zeros([1, 8, 2, 64], dtype: .bfloat16)
97+
let moreValues = MLXArray.zeros([1, 8, 2, 64], dtype: .bfloat16)
98+
99+
switch copied {
100+
case let arrays as ArraysCache:
101+
// overwrite slot 0 with a different array
102+
arrays[0] = moreKeys
103+
case let quantized as QuantizedKVCache:
104+
_ = quantized.updateQuantized(keys: moreKeys, values: moreValues)
105+
default:
106+
_ = copied.update(keys: moreKeys, values: moreValues)
107+
}
108+
109+
// original must be unchanged
110+
#expect(original.offset == originalOffset)
111+
#expect(original.metaState == originalMeta)
112+
let currentState = original.state
113+
eval(currentState)
114+
#expect(currentState.count == originalState.count)
115+
for (origArr, savedArr) in zip(currentState, originalState) {
116+
#expect(origArr.shape == savedArr.shape)
117+
#expect(allClose(origArr, savedArr).item(Bool.self))
118+
}
119+
}
120+
121+
/// copy() on an empty (unpopulated) cache must not crash.
122+
@Test(
123+
.serialized,
124+
arguments: cacheCreators)
125+
func testCacheCopyOnEmptyCache(creator: (() -> any KVCache)) async throws {
126+
let empty = creator()
127+
let copied = empty.copy()
128+
129+
#expect(type(of: empty) == type(of: copied))
130+
#expect(copied.offset == 0)
131+
#expect(copied.state.count == empty.state.count)
132+
}
133+
134+
/// CacheList.copy() produces independent sub-caches.
135+
@Test
136+
func testCacheListCopyIsIndependent() async throws {
137+
let sub1 = KVCacheSimple()
138+
let sub2 = RotatingKVCache(maxSize: 32)
139+
let composite = CacheList(sub1, sub2)
140+
141+
let keys = MLXArray.ones([1, 8, 4, 64], dtype: .bfloat16)
142+
let values = MLXArray.ones([1, 8, 4, 64], dtype: .bfloat16)
143+
_ = sub1.update(keys: keys, values: values)
144+
_ = sub2.update(keys: keys, values: values)
145+
146+
// snapshot original state — eval to materialize before copy
147+
let originalState = composite.state
148+
eval(originalState)
149+
let originalOffset0 = sub1.offset
150+
let originalOffset1 = sub2.offset
151+
152+
let copied = composite.copy()
153+
154+
#expect(copied is CacheList)
155+
let copiedState = copied.state
156+
eval(copiedState)
157+
#expect(copiedState.count == originalState.count)
158+
for (orig, copy) in zip(originalState, copiedState) {
159+
#expect(orig.shape == copy.shape)
160+
#expect(allClose(orig, copy).item(Bool.self))
161+
}
162+
163+
// mutate inside the copy
164+
let copiedList = copied as! CacheList
165+
_ = copiedList[0].update(
166+
keys: MLXArray.zeros([1, 8, 2, 64], dtype: .bfloat16),
167+
values: MLXArray.zeros([1, 8, 2, 64], dtype: .bfloat16)
168+
)
169+
170+
// originals unchanged
171+
#expect(sub1.offset == originalOffset0)
172+
#expect(sub2.offset == originalOffset1)
173+
let currentState = composite.state
174+
eval(currentState)
175+
#expect(currentState.count == originalState.count)
176+
for (orig, saved) in zip(currentState, originalState) {
177+
#expect(orig.shape == saved.shape)
178+
#expect(allClose(orig, saved).item(Bool.self))
179+
}
180+
}

0 commit comments

Comments
 (0)