diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index 080962482..6617e751f 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -21,7 +21,7 @@ def __init__(self, config: SentenceChunkerConfig): self.config = config self.chunker = ChonkieSentenceChunker( - tokenizer_or_token_counter=config.tokenizer_or_token_counter, + tokenizer=config.tokenizer_or_token_counter, chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap, min_sentences_per_chunk=config.min_sentences_per_chunk, diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py index 5fdb59058..153a5978c 100644 --- a/src/memos/mem_os/utils/format_utils.py +++ b/src/memos/mem_os/utils/format_utils.py @@ -1088,35 +1088,68 @@ def convert_activation_memory_to_serializable( for item in act_mem_items: # Extract basic information that can be serialized + # Infer counts/device/dtype compatibly for new/old DynamicCache APIs + mem = item.memory + key_layers = 0 + val_layers = 0 + device_str = "unknown" + dtype_str = "unknown" + + if mem: + if hasattr(mem, "layers") and mem.layers is not None: + key_layers = len(mem.layers) + val_layers = len(mem.layers) + # find first available tensor to report device/dtype + for lyr in mem.layers: + t = getattr(lyr, "keys", None) + if t is None: + t = getattr(lyr, "values", None) + if t is not None: + device_str = str(t.device) + dtype_str = str(t.dtype) + break + else: + key_layers = len(getattr(mem, "key_cache", []) or []) + val_layers = len(getattr(mem, "value_cache", []) or []) + if getattr(mem, "key_cache", None): + first = next((t for t in mem.key_cache if t is not None), None) + if first is not None: + device_str = str(first.device) + dtype_str = str(first.dtype) + serializable_item = { "id": item.id, "metadata": item.metadata, "memory_info": { "type": "DynamicCache", - "key_cache_layers": len(item.memory.key_cache) if item.memory else 0, - "value_cache_layers": len(item.memory.value_cache) if item.memory else 0, - "device": str(item.memory.key_cache[0].device) - if item.memory and item.memory.key_cache - else "unknown", - "dtype": str(item.memory.key_cache[0].dtype) - if item.memory and item.memory.key_cache - else "unknown", + "key_cache_layers": key_layers, + "value_cache_layers": val_layers, + "device": device_str, + "dtype": dtype_str, }, } # Add tensor shape information if available - if item.memory and item.memory.key_cache: + if item.memory: key_shapes = [] value_shapes = [] - - for i, key_tensor in enumerate(item.memory.key_cache): - if key_tensor is not None: - key_shapes.append({"layer": i, "shape": list(key_tensor.shape)}) - - if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None: - value_shapes.append( - {"layer": i, "shape": list(item.memory.value_cache[i].shape)} - ) + mem = item.memory + if hasattr(mem, "layers") and mem.layers is not None: + for i, layer in enumerate(mem.layers): + if getattr(layer, "keys", None) is not None: + key_shapes.append({"layer": i, "shape": list(layer.keys.shape)}) + if getattr(layer, "values", None) is not None: + value_shapes.append({"layer": i, "shape": list(layer.values.shape)}) + elif getattr(mem, "key_cache", None): + for i, key_tensor in enumerate(mem.key_cache): + if key_tensor is not None: + key_shapes.append({"layer": i, "shape": list(key_tensor.shape)}) + if ( + hasattr(mem, "value_cache") + and i < len(mem.value_cache) + and mem.value_cache[i] is not None + ): + value_shapes.append({"layer": i, "shape": list(mem.value_cache[i].shape)}) serializable_item["memory_info"]["key_shapes"] = key_shapes serializable_item["memory_info"]["value_shapes"] = value_shapes @@ -1144,15 +1177,22 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[ total_parameters = 0 for item in act_mem_items: - if item.memory and item.memory.key_cache: - total_layers += len(item.memory.key_cache) - - # Calculate approximate parameter count - for key_tensor in item.memory.key_cache: + mem = item.memory + if not mem: + continue + if hasattr(mem, "layers") and mem.layers is not None: + total_layers += len(mem.layers) + for layer in mem.layers: + if getattr(layer, "keys", None) is not None: + total_parameters += layer.keys.numel() + if getattr(layer, "values", None) is not None: + total_parameters += layer.values.numel() + elif getattr(mem, "key_cache", None): + total_layers += len(mem.key_cache) + for key_tensor in mem.key_cache: if key_tensor is not None: total_parameters += key_tensor.numel() - - for value_tensor in item.memory.value_cache: + for value_tensor in getattr(mem, "value_cache", []) or []: if value_tensor is not None: total_parameters += value_tensor.numel() diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py index 98d611dbf..9f7e83a27 100644 --- a/src/memos/memories/activation/kv.py +++ b/src/memos/memories/activation/kv.py @@ -2,9 +2,6 @@ import pickle from datetime import datetime -from importlib.metadata import version - -from packaging.version import Version from transformers import DynamicCache from memos.configs.memory import KVCacheMemoryConfig @@ -210,29 +207,26 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: if len(caches) == 1: return caches[0] - merged = DynamicCache() - num_layers = len(caches[0].key_cache) - - if Version(version("transformers")) >= Version("4.54.0"): - merged.append_new_layers(num_layers - 1) + # Newer transformers expose `layers` with `.keys`/`.values` + if hasattr(caches[0], "layers") and caches[0].layers is not None: + num_layers = len(caches[0].layers) + base = caches[0] for layer in range(num_layers): - # gather all K and V for this layer keys = [c.layers[layer].keys for c in caches] vals = [c.layers[layer].values for c in caches] - # single concat per layer - merged.layers[layer].keys = torch.cat(keys, dim=-2) - merged.layers[layer].values = torch.cat(vals, dim=-2) - + base.layers[layer].keys = torch.cat(keys, dim=-2) + base.layers[layer].values = torch.cat(vals, dim=-2) + return base else: + # Legacy API: key_cache/value_cache lists + merged = DynamicCache() + num_layers = len(caches[0].key_cache) for layer in range(num_layers): - # gather all K and V for this layer keys = [c.key_cache[layer] for c in caches] vals = [c.value_cache[layer] for c in caches] - # single concat per layer merged.key_cache.append(torch.cat(keys, dim=-2)) merged.value_cache.append(torch.cat(vals, dim=-2)) - - return merged + return merged def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache: diff --git a/tests/memories/activation/test_kv.py b/tests/memories/activation/test_kv.py index 6490d687f..20bcd2435 100644 --- a/tests/memories/activation/test_kv.py +++ b/tests/memories/activation/test_kv.py @@ -34,10 +34,20 @@ def kv_memory(dummy_config): def make_filled_cache(): - # Create a DynamicCache with at least one dummy tensor layer + # Create a DynamicCache with at least one dummy tensor layer, supporting new/old APIs cache = DynamicCache() - cache.key_cache.append(torch.zeros(1, 2, 3)) - cache.value_cache.append(torch.zeros(1, 2, 3)) + if hasattr(cache, "layers") and cache.layers is not None: + # For new API, append a layer-like object with keys/values tensors + class _Layer: + def __init__(self): + self.keys = torch.zeros(1, 2, 3) + self.values = torch.zeros(1, 2, 3) + + cache.layers.append(_Layer()) + else: + # Legacy API + cache.key_cache.append(torch.zeros(1, 2, 3)) + cache.value_cache.append(torch.zeros(1, 2, 3)) return cache @@ -58,9 +68,14 @@ def test_get_cache_merge(kv_memory): kv_memory.add([item1, item2]) merged = kv_memory.get_cache([item1.id, item2.id]) assert isinstance(merged, DynamicCache) - # Check the number of layers in merged key/value cache - assert len(merged.key_cache) == 1 - assert len(merged.value_cache) == 1 + # Check the number of layers in merged cache (new or old API) + if hasattr(merged, "layers") and merged.layers is not None: + assert len(merged.layers) == 1 + assert getattr(merged.layers[0], "keys", None) is not None + assert getattr(merged.layers[0], "values", None) is not None + else: + assert len(merged.key_cache) == 1 + assert len(merged.value_cache) == 1 def test_delete_and_get_all(kv_memory):