Commit f2788b6

fix bug when calling _concat_caches in kv.py

1 parent d131bce

src/memos/memories/activation/kv.py

Lines changed: 6 additions & 4 deletions
@@ -211,13 +211,15 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
         merged = DynamicCache()
         num_layers = len(caches[0].key_cache)
 
+        merged.append_new_layers(num_layers - 1)
+
         for layer in range(num_layers):
             # gather all K and V for this layer
-            keys = [c.key_cache[layer] for c in caches]
-            vals = [c.value_cache[layer] for c in caches]
+            keys = [c.layers[layer].keys for c in caches]
+            vals = [c.layers[layer].values for c in caches]
             # single concat per layer
-            merged.key_cache.append(torch.cat(keys, dim=-2))
-            merged.value_cache.append(torch.cat(vals, dim=-2))
+            merged.layers[layer].keys = (torch.cat(keys, dim=-2))
+            merged.layers[layer].values = (torch.cat(vals, dim=-2))
 
         return merged
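
For context, below is a minimal standalone sketch of the patched logic plus a toy usage check. The free-standing concat_caches wrapper, the tensor shapes, and the expected output shape are illustrative assumptions inferred from the diff, not part of the commit; it assumes a transformers version whose DynamicCache exposes the layered API (cache.layers[i].keys / .values and append_new_layers) that the patch itself uses.

import torch
from transformers import DynamicCache


def concat_caches(caches: list[DynamicCache]) -> DynamicCache:
    # Standalone restatement of the patched method, for illustration only.
    merged = DynamicCache()
    num_layers = len(caches[0].key_cache)
    # Pre-create per-layer slots so merged.layers[layer] is addressable below.
    merged.append_new_layers(num_layers - 1)

    for layer in range(num_layers):
        # Gather all K and V tensors for this layer across the input caches.
        keys = [c.layers[layer].keys for c in caches]
        vals = [c.layers[layer].values for c in caches]
        # Single concat per layer along the sequence dimension (-2).
        merged.layers[layer].keys = torch.cat(keys, dim=-2)
        merged.layers[layer].values = torch.cat(vals, dim=-2)

    return merged


# Toy usage: two single-layer caches with sequence lengths 3 and 5.
# Shapes (batch=1, heads=2, head_dim=4) are arbitrary illustrative sizes.
a, b = DynamicCache(), DynamicCache()
a.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), layer_idx=0)
b.update(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4), layer_idx=0)

merged = concat_caches([a, b])
print(merged.layers[0].keys.shape)  # expected: torch.Size([1, 2, 8, 4])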
