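"""Ingest a LoCoMo conversation into MemOS.

Loads one LoCoMo sample, registers its two speakers as MOSProduct users
(each with a Neo4j-backed memory cube), and adds the first sessions of the
dialogue to both users' memories, each from that user's own perspective.
"""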
import json
import os

from dotenv import load_dotenv

from memos import log
from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
from memos.mem_cube.general import GeneralMemCube
from memos.mem_os.product import MOSProduct


load_dotenv()


logger = log.get_logger(__name__)


# === Load conversation ===
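# locomo10.json contains a list of LoCoMo samples; index 3 selects the fourth conversation.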
with open("evaluation/data/locomo/locomo10.json", encoding="utf-8") as f:
    conversation = json.load(f)
    data = conversation[3]
    speaker_a = data["conversation"]["speaker_a"]
    speaker_b = data["conversation"]["speaker_b"]
    conversation_i = data["conversation"]

db_name = "shared-db-locomo-case"

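# Shared OpenAI settings for chat, extraction, and embeddings.
# Replace the placeholder api_key (or read it from the environment, since dotenv is loaded above).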
openapi_config = {
    "model_name_or_path": "gpt-4o-mini",
    "temperature": 0.8,
    "max_tokens": 1024,
    "api_key": "your-api-key-here",
    "api_base": "https://api.openai.com/v1",
}


# === Create MOS Config ===
def get_user_configs(user_name):
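    """Build a per-user MOSConfig: OpenAI chat/extraction LLM, OpenAI embeddings, sentence chunking."""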
    mos_config = MOSConfig(
        user_id=user_name,
        chat_model={"backend": "openai", "config": openapi_config},
        mem_reader={
            "backend": "simple_struct",
            "config": {
                "llm": {"backend": "openai", "config": openapi_config},
                "embedder": {
                    "backend": "universal_api",
                    "config": {
                        "provider": "openai",
                        "api_key": openapi_config["api_key"],
                        "model_name_or_path": "text-embedding-3-large",
                        "base_url": openapi_config["api_base"],
                    },
                },
                "chunker": {
                    "backend": "sentence",
                    "config": {
                        "tokenizer_or_token_counter": "gpt2",
                        "chunk_size": 512,
                        "chunk_overlap": 128,
                        "min_sentences_per_chunk": 1,
                    },
                },
            },
        },
        enable_textual_memory=True,
        enable_activation_memory=False,
        enable_parametric_memory=False,
        top_k=5,
        max_turns_window=20,
    )

    return mos_config


# === Get Memory Cube Config ===
def get_mem_cube_config(user_name):
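    """Build a GeneralMemCubeConfig backed by Neo4j and return an instantiated GeneralMemCube."""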
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_config = {
        "uri": neo4j_uri,
        "user": "neo4j",
        "password": "12345678",
        "db_name": db_name,
        "user_name": "will be updated",
        "use_multi_db": False,
        "embedding_dimension": 3072,
        "auto_create": True,
    }
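    # The tree_text backend stores memories as a graph in Neo4j; embedding_dimension above (3072)
    # matches the text-embedding-3-large model configured below.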
    cube_config = GeneralMemCubeConfig.model_validate(
        {
            "user_id": user_name,
            "cube_id": f"{user_name}_cube",
            "text_mem": {
                "backend": "tree_text",
                "config": {
                    "extractor_llm": {"backend": "openai", "config": openapi_config},
                    "dispatcher_llm": {"backend": "openai", "config": openapi_config},
                    "graph_db": {"backend": "neo4j", "config": neo4j_config},
                    "embedder": {
                        "backend": "universal_api",
                        "config": {
                            "provider": "openai",
                            "api_key": openapi_config["api_key"],
                            "model_name_or_path": "text-embedding-3-large",
                            "base_url": openapi_config["api_base"],
                        },
                    },
                    "reorganize": True,
                },
            },
        }
    )

    mem_cube = GeneralMemCube(cube_config)
    return mem_cube


# === Initialize MOSProduct ===
root_config = get_user_configs(user_name="system")
mos_product = MOSProduct(default_config=root_config)


# === Register both users ===
users = {}
for speaker in [speaker_a, speaker_b]:
    user_id = speaker.lower() + "_test"
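    # Build a per-user config and memory cube. Note that the MOSConfig is not
    # passed to user_register here; only the cube is attached as the default.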
    config = get_user_configs(user_id)
    mem_cube = get_mem_cube_config(user_id)
    result = mos_product.user_register(
        user_id=user_id,
        user_name=speaker,
        interests=f"I'm {speaker}",
        default_mem_cube=mem_cube,
    )
    users[speaker] = {"user_id": user_id, "default_cube_id": result["default_cube_id"]}
    print(f"✅ Registered: {speaker} -> {result}")

# === Process conversation, add to both roles ===
i = 1
MAX_CONVERSATION_FOR_TEST = 3
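# Ingest sessions in order while they exist in the sample, capped for this test
# run (i < 3, so only sessions 1 and 2 are processed).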
while (
    f"session_{i}_date_time" in conversation_i and f"session_{i}" in conversation_i
) and i < MAX_CONVERSATION_FOR_TEST:
    session_i = conversation_i[f"session_{i}"]
    session_time = conversation_i[f"session_{i}_date_time"]

    print(f"\n=== Processing Session {i} | Time: {session_time} ===")

    role1_msgs, role2_msgs = [], []

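    # Each utterance is stored twice: as a "user" message in its speaker's cube and as an
    # "assistant" message in the other speaker's cube, so both users keep the full dialogue
    # from their own perspective.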
    for m in session_i:
        if m["speaker"] == speaker_a:
            role1_msgs.append(
                {
                    "role": "user",
                    "content": f"{m['speaker']}:{m['text']}",
                    "chat_time": session_time,
                }
            )
            role2_msgs.append(
                {
                    "role": "assistant",
                    "content": f"{m['speaker']}:{m['text']}",
                    "chat_time": session_time,
                }
            )
        elif m["speaker"] == speaker_b:
            role1_msgs.append(
                {
                    "role": "assistant",
                    "content": f"{m['speaker']}:{m['text']}",
                    "chat_time": session_time,
                }
            )
            role2_msgs.append(
                {
                    "role": "user",
                    "content": f"{m['speaker']}:{m['text']}",
                    "chat_time": session_time,
                }
            )

    print(f"\n[Session {i}] {speaker_a} will add {len(role1_msgs)} messages.")
    print(f"[Session {i}] {speaker_b} will add {len(role2_msgs)} messages.")

    mos_product.add(
        user_id=users[speaker_a]["user_id"],
        messages=role1_msgs,
        mem_cube_id=users[speaker_a]["default_cube_id"],
    )
    mos_product.add(
        user_id=users[speaker_b]["user_id"],
        messages=role2_msgs,
        mem_cube_id=users[speaker_b]["default_cube_id"],
    )

    print(f"[Session {i}] Added messages for both roles")

    i += 1

print("\n✅ All messages added for both roles.\n")
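# Turn off the background memory reorganizer (enabled via "reorganize": True above) before exiting.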
mos_product.mem_reorganizer_off()