
Commit aa11489

fix: fix wrong param in the whole repo
1 parent 4809d32 commit aa11489

4 files changed: 129 additions & 63 deletions

src/memos/llms/hf.py

Lines changed: 9 additions & 8 deletions
@@ -383,21 +383,22 @@ def build_kv_cache(self, messages) -> DynamicCache:
         with torch.no_grad():
             self.model(**inputs, use_cache=True, past_key_values=kv)
         try:
-            if hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
-                for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache)):
-                    if isinstance(k, torch.Tensor):
-                        kv.key_cache[i] = k[..., :seq_len, :]
-                    if isinstance(v, torch.Tensor):
-                        kv.value_cache[i] = v[..., :seq_len, :]
-            elif hasattr(kv, "layers"):
+            # Prefer new API first
+            if hasattr(kv, "layers") and kv.layers is not None:
                 for layer in kv.layers:
                     if hasattr(layer, "keys") and isinstance(layer.keys, torch.Tensor):
                         layer.keys = layer.keys[..., :seq_len, :]
                     if hasattr(layer, "values") and isinstance(layer.values, torch.Tensor):
                         layer.values = layer.values[..., :seq_len, :]
+            elif hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
+                for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache)):
+                    if isinstance(k, torch.Tensor):
+                        kv.key_cache[i] = k[..., :seq_len, :]
+                    if isinstance(v, torch.Tensor):
+                        kv.value_cache[i] = v[..., :seq_len, :]
             else:
                 logger.warning(
-                    "DynamicCache object has no key_cache/value_cache or layers attributes; returning unmodified cache"
+                    "DynamicCache object has no layers or key_cache/value_cache attributes; returning unmodified cache"
                 )
         except Exception as e:
             logger.exception("Failed while trimming KV cache to seq_len: %s", e)

src/memos/mem_os/utils/format_utils.py

Lines changed: 65 additions & 25 deletions
@@ -1088,35 +1088,68 @@ def convert_activation_memory_to_serializable(
 
     for item in act_mem_items:
         # Extract basic information that can be serialized
+        # Infer counts/device/dtype compatibly for new/old DynamicCache APIs
+        mem = item.memory
+        key_layers = 0
+        val_layers = 0
+        device_str = "unknown"
+        dtype_str = "unknown"
+
+        if mem:
+            if hasattr(mem, "layers") and mem.layers is not None:
+                key_layers = len(mem.layers)
+                val_layers = len(mem.layers)
+                # find first available tensor to report device/dtype
+                for lyr in mem.layers:
+                    t = getattr(lyr, "keys", None)
+                    if t is None:
+                        t = getattr(lyr, "values", None)
+                    if t is not None:
+                        device_str = str(t.device)
+                        dtype_str = str(t.dtype)
+                        break
+            else:
+                key_layers = len(getattr(mem, "key_cache", []) or [])
+                val_layers = len(getattr(mem, "value_cache", []) or [])
+                if getattr(mem, "key_cache", None):
+                    first = next((t for t in mem.key_cache if t is not None), None)
+                    if first is not None:
+                        device_str = str(first.device)
+                        dtype_str = str(first.dtype)
+
         serializable_item = {
             "id": item.id,
             "metadata": item.metadata,
             "memory_info": {
                 "type": "DynamicCache",
-                "key_cache_layers": len(item.memory.key_cache) if item.memory else 0,
-                "value_cache_layers": len(item.memory.value_cache) if item.memory else 0,
-                "device": str(item.memory.key_cache[0].device)
-                if item.memory and item.memory.key_cache
-                else "unknown",
-                "dtype": str(item.memory.key_cache[0].dtype)
-                if item.memory and item.memory.key_cache
-                else "unknown",
+                "key_cache_layers": key_layers,
+                "value_cache_layers": val_layers,
+                "device": device_str,
+                "dtype": dtype_str,
             },
         }
 
         # Add tensor shape information if available
-        if item.memory and item.memory.key_cache:
+        if item.memory:
             key_shapes = []
             value_shapes = []
-
-            for i, key_tensor in enumerate(item.memory.key_cache):
-                if key_tensor is not None:
-                    key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
-
-                    if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None:
-                        value_shapes.append(
-                            {"layer": i, "shape": list(item.memory.value_cache[i].shape)}
-                        )
+            mem = item.memory
+            if hasattr(mem, "layers") and mem.layers is not None:
+                for i, layer in enumerate(mem.layers):
+                    if getattr(layer, "keys", None) is not None:
+                        key_shapes.append({"layer": i, "shape": list(layer.keys.shape)})
+                    if getattr(layer, "values", None) is not None:
+                        value_shapes.append({"layer": i, "shape": list(layer.values.shape)})
+            elif getattr(mem, "key_cache", None):
+                for i, key_tensor in enumerate(mem.key_cache):
+                    if key_tensor is not None:
+                        key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+                    if (
+                        hasattr(mem, "value_cache")
+                        and i < len(mem.value_cache)
+                        and mem.value_cache[i] is not None
+                    ):
+                        value_shapes.append({"layer": i, "shape": list(mem.value_cache[i].shape)})
 
             serializable_item["memory_info"]["key_shapes"] = key_shapes
             serializable_item["memory_info"]["value_shapes"] = value_shapes
@@ -1144,15 +1177,22 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[
     total_parameters = 0
 
     for item in act_mem_items:
-        if item.memory and item.memory.key_cache:
-            total_layers += len(item.memory.key_cache)
-
-            # Calculate approximate parameter count
-            for key_tensor in item.memory.key_cache:
+        mem = item.memory
+        if not mem:
+            continue
+        if hasattr(mem, "layers") and mem.layers is not None:
+            total_layers += len(mem.layers)
+            for layer in mem.layers:
+                if getattr(layer, "keys", None) is not None:
+                    total_parameters += layer.keys.numel()
+                if getattr(layer, "values", None) is not None:
+                    total_parameters += layer.values.numel()
+        elif getattr(mem, "key_cache", None):
+            total_layers += len(mem.key_cache)
+            for key_tensor in mem.key_cache:
                 if key_tensor is not None:
                     total_parameters += key_tensor.numel()
-
-            for value_tensor in item.memory.value_cache:
+            for value_tensor in getattr(mem, "value_cache", []) or []:
                 if value_tensor is not None:
                     total_parameters += value_tensor.numel()
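The summary path counts layers and tensor elements identically under either API. A quick sanity check of the counting logic on a synthetic cache (standalone sketch, no KVCacheItem wrapper):

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4), layer_idx=0)

total_parameters = 0
if hasattr(cache, "layers") and cache.layers is not None:
    for layer in cache.layers:
        if getattr(layer, "keys", None) is not None:
            total_parameters += layer.keys.numel()
        if getattr(layer, "values", None) is not None:
            total_parameters += layer.values.numel()
elif getattr(cache, "key_cache", None):
    total_parameters += sum(t.numel() for t in cache.key_cache if t is not None)
    total_parameters += sum(t.numel() for t in cache.value_cache if t is not None)

print(total_parameters)  # 128: two 1x2x8x4 tensors of 64 elements each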

src/memos/memories/activation/kv.py

Lines changed: 34 additions & 24 deletions
@@ -2,9 +2,6 @@
 import pickle
 
 from datetime import datetime
-from importlib.metadata import version
-
-from packaging.version import Version
 from transformers import DynamicCache
 
 from memos.configs.memory import KVCacheMemoryConfig
@@ -210,29 +207,26 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
         if len(caches) == 1:
             return caches[0]
 
-        merged = DynamicCache()
-        num_layers = len(caches[0].key_cache)
-
-        if Version(version("transformers")) >= Version("4.54.0"):
-            merged.append_new_layers(num_layers - 1)
+        # Newer transformers expose `layers` with `.keys`/`.values`
+        if hasattr(caches[0], "layers") and caches[0].layers is not None:
+            num_layers = len(caches[0].layers)
+            base = caches[0]
             for layer in range(num_layers):
-                # gather all K and V for this layer
                 keys = [c.layers[layer].keys for c in caches]
                 vals = [c.layers[layer].values for c in caches]
-                # single concat per layer
-                merged.layers[layer].keys = torch.cat(keys, dim=-2)
-                merged.layers[layer].values = torch.cat(vals, dim=-2)
-
+                base.layers[layer].keys = torch.cat(keys, dim=-2)
+                base.layers[layer].values = torch.cat(vals, dim=-2)
+            return base
         else:
+            # Legacy API: key_cache/value_cache lists
+            merged = DynamicCache()
+            num_layers = len(caches[0].key_cache)
             for layer in range(num_layers):
-                # gather all K and V for this layer
                 keys = [c.key_cache[layer] for c in caches]
                 vals = [c.value_cache[layer] for c in caches]
-                # single concat per layer
                 merged.key_cache.append(torch.cat(keys, dim=-2))
                 merged.value_cache.append(torch.cat(vals, dim=-2))
-
-        return merged
+            return merged
 
 
 def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache:
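Note that the new-API branch now mutates caches[0] in place and returns it, rather than probing the transformers version and pre-allocating layers; concatenation along dim=-2 stitches caches on the sequence axis. A hedged usage sketch of the same per-layer merge for two single-layer caches:

import torch
from transformers import DynamicCache

a, b = DynamicCache(), DynamicCache()
a.update(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4), layer_idx=0)
b.update(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4), layer_idx=0)

if hasattr(a, "layers") and a.layers is not None:
    # New API: reuse the first cache as the merge target.
    for la, lb in zip(a.layers, b.layers):
        la.keys = torch.cat([la.keys, lb.keys], dim=-2)
        la.values = torch.cat([la.values, lb.values], dim=-2)
    merged = a
else:
    # Legacy API: build a fresh cache from the concatenated lists.
    merged = DynamicCache()
    for ka, kb, va, vb in zip(a.key_cache, b.key_cache, a.value_cache, b.value_cache):
        merged.key_cache.append(torch.cat([ka, kb], dim=-2))
        merged.value_cache.append(torch.cat([va, vb], dim=-2))
# The merged cache now holds 16 sequence positions per layer.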
@@ -242,11 +236,27 @@ def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> Dynamic
     So before inferring with DynamicCache, we should move it to GPU in-place first.
     """
     # Currently, we put this function outside [class KVCacheMemory]
-    for i in range(len(dynamic_cache.key_cache)):
-        if dynamic_cache.key_cache[i] is not None:
-            dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to(device, non_blocking=True)
-        if dynamic_cache.value_cache[i] is not None:
-            dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to(
-                device, non_blocking=True
-            )
+    # Support both old API (key_cache/value_cache) and new API (layers with keys/values)
+    if hasattr(dynamic_cache, "layers") and dynamic_cache.layers is not None:
+        for i, layer in enumerate(dynamic_cache.layers):
+            # Each layer is expected to have `.keys` and `.values` tensors
+            if hasattr(layer, "keys") and layer.keys is not None:
+                layer.keys = layer.keys.to(device, non_blocking=True)
+            if hasattr(layer, "values") and layer.values is not None:
+                layer.values = layer.values.to(device, non_blocking=True)
+    else:
+        # Fallback to legacy attributes
+        for i in range(len(getattr(dynamic_cache, "key_cache", []))):
+            if dynamic_cache.key_cache[i] is not None:
+                dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to(
+                    device, non_blocking=True
+                )
+            if (
+                hasattr(dynamic_cache, "value_cache")
+                and i < len(dynamic_cache.value_cache)
+                and dynamic_cache.value_cache[i] is not None
+            ):
+                dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to(
+                    device, non_blocking=True
+                )
     return dynamic_cache
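Because caches are pickled to and restored on CPU, moving one onto the GPU in place before inference looks like this (guarded so the sketch also runs on CPU-only machines):

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4), layer_idx=0)

if torch.cuda.is_available():
    move_dynamic_cache_htod(cache, "cuda:0")  # tensors now live on the GPU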

tests/memories/activation/test_kv.py

Lines changed: 21 additions & 6 deletions
@@ -34,10 +34,20 @@ def kv_memory(dummy_config):
 
 
 def make_filled_cache():
-    # Create a DynamicCache with at least one dummy tensor layer
+    # Create a DynamicCache with at least one dummy tensor layer, supporting new/old APIs
     cache = DynamicCache()
-    cache.key_cache.append(torch.zeros(1, 2, 3))
-    cache.value_cache.append(torch.zeros(1, 2, 3))
+    if hasattr(cache, "layers") and cache.layers is not None:
+        # For new API, append a layer-like object with keys/values tensors
+        class _Layer:
+            def __init__(self):
+                self.keys = torch.zeros(1, 2, 3)
+                self.values = torch.zeros(1, 2, 3)
+
+        cache.layers.append(_Layer())
+    else:
+        # Legacy API
+        cache.key_cache.append(torch.zeros(1, 2, 3))
+        cache.value_cache.append(torch.zeros(1, 2, 3))
     return cache
 
 
@@ -58,9 +68,14 @@ def test_get_cache_merge(kv_memory):
     kv_memory.add([item1, item2])
     merged = kv_memory.get_cache([item1.id, item2.id])
     assert isinstance(merged, DynamicCache)
-    # Check the number of layers in merged key/value cache
-    assert len(merged.key_cache) == 1
-    assert len(merged.value_cache) == 1
+    # Check the number of layers in merged cache (new or old API)
+    if hasattr(merged, "layers") and merged.layers is not None:
+        assert len(merged.layers) == 1
+        assert getattr(merged.layers[0], "keys", None) is not None
+        assert getattr(merged.layers[0], "values", None) is not None
+    else:
+        assert len(merged.key_cache) == 1
+        assert len(merged.value_cache) == 1
 
 
 def test_delete_and_get_all(kv_memory):
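The API branching repeated across these tests could be factored into a small accessor; a hypothetical helper (not in the repo) returning the (keys, values) pair for layer i under either API:

import torch
from transformers import DynamicCache

def get_layer_kv(cache: DynamicCache, i: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical convenience accessor, not part of the repo.
    if hasattr(cache, "layers") and cache.layers is not None:
        return cache.layers[i].keys, cache.layers[i].values
    return cache.key_cache[i], cache.value_cache[i]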
