
Commit fb9e539

fix: fix bug when calling _concat_caches in kv.py (from pr#177)

Fix a bug when calling _concat_caches in kv.py, from PR #177 by @RixinLiu.

1 parent: cd99dc5

File tree

  • src/memos/memories/activation/kv.py

1 file changed: 5 additions, 4 deletions

src/memos/memories/activation/kv.py (5 additions, 4 deletions)
@@ -210,14 +210,15 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:

         merged = DynamicCache()
         num_layers = len(caches[0].key_cache)
+        merged.append_new_layers(num_layers - 1)

         for layer in range(num_layers):
             # gather all K and V for this layer
-            keys = [c.key_cache[layer] for c in caches]
-            vals = [c.value_cache[layer] for c in caches]
+            keys = [c.layers[layer].keys for c in caches]
+            vals = [c.layers[layer].values for c in caches]
             # single concat per layer
-            merged.key_cache.append(torch.cat(keys, dim=-2))
-            merged.value_cache.append(torch.cat(vals, dim=-2))
+            merged.layers[layer].keys = (torch.cat(keys, dim=-2))
+            merged.layers[layer].values = (torch.cat(vals, dim=-2))

         return merged

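For context, here is the patched helper pulled out of its class as a standalone function. This is a minimal sketch, assuming the post-refactor transformers DynamicCache API used in the diff above (per-layer objects under .layers exposing .keys and .values, plus an append_new_layers method); the name concat_caches and the dummy-tensor check at the bottom are illustrative only and not part of the commit.

import torch
from transformers import DynamicCache

def concat_caches(caches: list[DynamicCache]) -> DynamicCache:
    # Merge several caches by concatenating K/V along the sequence axis (dim=-2).
    merged = DynamicCache()
    num_layers = len(caches[0].key_cache)
    # Pre-create the layer slots so merged.layers[layer] can be assigned below.
    merged.append_new_layers(num_layers - 1)
    for layer in range(num_layers):
        # gather all K and V for this layer
        keys = [c.layers[layer].keys for c in caches]
        vals = [c.layers[layer].values for c in caches]
        # single concat per layer
        merged.layers[layer].keys = torch.cat(keys, dim=-2)
        merged.layers[layer].values = torch.cat(vals, dim=-2)
    return merged

# Hypothetical check with dummy tensors shaped (batch, heads, seq_len, head_dim):
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
a, b = DynamicCache(), DynamicCache()
a.update(k, v, layer_idx=0)
b.update(k, v, layer_idx=0)
merged = concat_caches([a, b])
print(merged.layers[0].keys.shape)  # expected: torch.Size([1, 2, 8, 8])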