Skip to content

Commit 91562b3

Browse files
committed
feat: add vllm cahche
1 parent 4311477 commit 91562b3

File tree

8 files changed

+371
-6
lines changed

8 files changed

+371
-6
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Example demonstrating how to use VLLMKVCacheMemory with vLLM backend.
4+
This example shows how to use the new vLLM-compatible KV cache memory.
5+
"""
6+
7+
from memos.configs.memory import MemoryConfigFactory
8+
from memos.memories.factory import MemoryFactory
9+
10+
11+
def main():
12+
"""Main function demonstrating VLLMKVCacheMemory usage."""
13+
14+
print("=== VLLM KV Cache Memory Example ===\n")
15+
16+
# 1. Create config for VLLMKVCacheMemory (using vLLM backend)
17+
config = MemoryConfigFactory(
18+
backend="vllm_kv_cache", # Use the new vLLM KV cache backend
19+
config={
20+
"extractor_llm": {
21+
"backend": "vllm",
22+
"config": {
23+
"model_name_or_path": "/mnt/afs/models/hf_models/Qwen2.5-7B",
24+
"api_base": "http://localhost:8088/v1",
25+
"temperature": 0.7,
26+
"max_tokens": 1024,
27+
"model_schema": "memos.configs.llm.VLLMLLMConfig",
28+
},
29+
},
30+
},
31+
)
32+
33+
# 2. Instantiate VLLMKVCacheMemory using the factory
34+
print("Initializing VLLM KV Cache Memory...")
35+
vllm_kv_mem = MemoryFactory.from_config(config)
36+
print("✓ VLLM KV Cache Memory initialized successfully.\n")
37+
38+
# 3. Extract a VLLMKVCacheItem from a prompt
39+
print("===== Extract VLLMKVCacheItem =====")
40+
system_prompt = [
41+
{"role": "system", "content": "You are a helpful AI assistant."},
42+
{"role": "user", "content": "What is MemOS?"},
43+
{"role": "assistant", "content": "MemOS is a memory operating system for LLMs."},
44+
]
45+
46+
try:
47+
cache_item = vllm_kv_mem.extract(system_prompt)
48+
print("✓ KV cache item extracted successfully")
49+
print(f" ID: {cache_item.id}")
50+
print(f" Memory (prompt): {cache_item.memory[:100]}...")
51+
print(f" Metadata: {cache_item.metadata}")
52+
print()
53+
except Exception as e:
54+
print(f"✗ Failed to extract KV cache item: {e}")
55+
return
56+
57+
# 4. Add the extracted VLLMKVCacheItem
58+
print("===== Add VLLMKVCacheItem =====")
59+
vllm_kv_mem.add([cache_item])
60+
all_items = vllm_kv_mem.get_all()
61+
print(f"✓ Added cache item. Total items: {len(all_items)}")
62+
print()
63+
64+
# 5. Get by id
65+
print("===== Get VLLMKVCacheItem by id =====")
66+
retrieved = vllm_kv_mem.get(cache_item.id)
67+
if retrieved:
68+
print(f"✓ Retrieved cache item: {retrieved.id}")
69+
print(f" Memory (prompt): {retrieved.memory[:100]}...")
70+
else:
71+
print("✗ Failed to retrieve cache item")
72+
print()
73+
74+
# 6. Get cache (returns prompt string for vLLM)
75+
print("===== Get Cache (Prompt String) =====")
76+
prompt_string = vllm_kv_mem.get_cache([cache_item.id])
77+
if prompt_string:
78+
print(f"✓ Retrieved prompt string: {prompt_string[:100]}...")
79+
print(" This prompt can be used for vLLM generation with preloaded KV cache")
80+
else:
81+
print("✗ Failed to retrieve prompt string")
82+
print()
83+
84+
# 7. Extract another cache item for demonstration
85+
print("===== Extract Another VLLMKVCacheItem =====")
86+
another_prompt = [
87+
{"role": "system", "content": "You are a coding assistant."},
88+
{"role": "user", "content": "Write a Python function to calculate fibonacci numbers."},
89+
]
90+
91+
try:
92+
cache_item2 = vllm_kv_mem.extract(another_prompt)
93+
vllm_kv_mem.add([cache_item2])
94+
print(f"✓ Added second cache item. Total items: {len(vllm_kv_mem.get_all())}")
95+
print()
96+
except Exception as e:
97+
print(f"✗ Failed to extract second KV cache item: {e}")
98+
print()
99+
100+
# 8. Preload KV cache on vLLM server
101+
print("===== Preload KV Cache on vLLM Server =====")
102+
try:
103+
vllm_kv_mem.preload_kv_cache([cache_item.id, cache_item2.id])
104+
print("✓ KV cache preloaded on vLLM server successfully")
105+
print(" The server now has the KV cache ready for fast generation")
106+
except Exception as e:
107+
print(f"✗ Failed to preload KV cache: {e}")
108+
print()
109+
110+
# 9. Delete one item
111+
print("===== Delete One VLLMKVCacheItem =====")
112+
vllm_kv_mem.delete([cache_item.id])
113+
remaining_items = vllm_kv_mem.get_all()
114+
print(f"✓ Deleted cache item. Remaining items: {len(remaining_items)}")
115+
print()
116+
117+
# 10. Dump and load
118+
print("===== Dump and Load VLLMKVCacheMemory =====")
119+
try:
120+
vllm_kv_mem.dump("tmp/vllm_kv_mem")
121+
print("✓ Memory dumped to 'tmp/vllm_kv_mem'")
122+
123+
# Clear memory and reload
124+
vllm_kv_mem.delete_all()
125+
vllm_kv_mem.load("tmp/vllm_kv_mem")
126+
reloaded_items = vllm_kv_mem.get_all()
127+
print(f"✓ Memory loaded from 'tmp/vllm_kv_mem': {len(reloaded_items)} items")
128+
except Exception as e:
129+
print(f"✗ Failed to dump/load memory: {e}")
130+
print()
131+
132+
print("=== Example completed successfully ===")
133+
134+
135+
if __name__ == "__main__":
136+
main()

src/memos/api/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_activation_config() -> dict[str, Any]:
7979
def get_activation_vllm_config() -> dict[str, Any]:
8080
"""Get Ollama configuration."""
8181
return {
82-
"backend": "kv_cache",
82+
"backend": "vllm_kv_cache",
8383
"config": {
8484
"memory_filename": "activation_memory.pickle",
8585
"extractor_llm": {
@@ -121,6 +121,7 @@ def get_scheduler_config() -> dict[str, Any]:
121121
"MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true"
122122
).lower()
123123
== "true",
124+
"enable_act_memory_update": True,
124125
},
125126
}
126127

src/memos/configs/mem_cube.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def validate_text_mem(cls, text_mem: MemoryConfigFactory) -> MemoryConfigFactory
7070
@classmethod
7171
def validate_act_mem(cls, act_mem: MemoryConfigFactory) -> MemoryConfigFactory:
7272
"""Validate the act_mem field."""
73-
allowed_backends = ["kv_cache", "uninitialized"]
73+
allowed_backends = ["kv_cache", "vllm_kv_cache", "uninitialized"]
7474
if act_mem.backend not in allowed_backends:
7575
raise ConfigurationError(
7676
f"GeneralMemCubeConfig requires act_mem backend to be one of {allowed_backends}, got '{act_mem.backend}'"

src/memos/configs/memory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class MemoryConfigFactory(BaseConfig):
181181
"general_text": GeneralTextMemoryConfig,
182182
"tree_text": TreeTextMemoryConfig,
183183
"kv_cache": KVCacheMemoryConfig,
184+
"vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache
184185
"lora": LoRAMemoryConfig,
185186
"uninitialized": UninitializedMemoryConfig,
186187
}

src/memos/mem_os/product.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from memos.mem_cube.general import GeneralMemCube
1515
from memos.mem_os.core import MOSCore
1616
from memos.mem_os.utils.format_utils import (
17-
convert_activation_memory_to_serializable,
1817
convert_graph_to_tree_forworkmem,
1918
filter_nodes_by_tree_ids,
2019
remove_embedding_recursive,
@@ -903,12 +902,11 @@ def get_all(
903902
)
904903
elif memory_type == "para_mem":
905904
act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all()
906-
# Convert activation memory to serializable format
907-
serializable_act_mem = convert_activation_memory_to_serializable(act_mem_params)
905+
logger.info(f"act_mem_params: {act_mem_params}")
908906
reformat_memory_list.append(
909907
{
910908
"cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0],
911-
"memories": serializable_act_mem,
909+
"memories": act_mem_params[0].model_dump(),
912910
}
913911
)
914912
return reformat_memory_list

src/memos/memories/activation/item.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,16 @@ class KVCacheItem(ActivationMemoryItem):
3535

3636
model_config = ConfigDict(arbitrary_types_allowed=True) # To allow DynamicCache as a field type
3737
records: KVCacheRecords = KVCacheRecords()
38+
39+
40+
class VLLMKVCacheItem(KVCacheItem):
41+
"""
42+
VLLM KV Cache Item that stores prompt strings instead of DynamicCache objects.
43+
This is because vLLM handles KV cache on the server side via preloading.
44+
"""
45+
46+
# Override memory field to store prompt string instead of DynamicCache
47+
memory: str = Field(
48+
default="",
49+
description="Prompt string used to preload KV cache in vLLM server",
50+
)

0 commit comments

Comments
 (0)