Skip to content

Commit ec3d657

Browse files
committed
feat & refactor: enable mem scheduler to load auth config from environment variables, refactor AuthConfig and EnvConfigMixin for improved robustness and smarter configuration handling, and allow the mem scheduler to initialize modules with RabbitMQ support.
1 parent e8346fc commit ec3d657

File tree

13 files changed

+432
-165
lines changed

13 files changed

+432
-165
lines changed

evaluation/.env-example

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,24 @@ ZEP_API_KEY="z_***REDACTED***"
99
CHAT_MODEL="gpt-4o-mini"
1010
CHAT_MODEL_BASE_URL="http://***.***.***.***:3000/v1"
1111
CHAT_MODEL_API_KEY="sk-***REDACTED***"
12+
13+
# Configuration Only For Scheduler
14+
# RabbitMQ Configuration
15+
MEMSCHEDULER_RABBITMQ_HOST_NAME=rabbitmq-cn-***.cn-***.amqp-32.net.mq.amqp.aliyuncs.com
16+
MEMSCHEDULER_RABBITMQ_USER_NAME=***
17+
MEMSCHEDULER_RABBITMQ_PASSWORD=***
18+
MEMSCHEDULER_RABBITMQ_VIRTUAL_HOST=memos
19+
MEMSCHEDULER_RABBITMQ_ERASE_ON_CONNECT=true
20+
MEMSCHEDULER_RABBITMQ_PORT=5672
21+
22+
# OpenAI Configuration
23+
MEMSCHEDULER_OPENAI_API_KEY=sk-***
24+
MEMSCHEDULER_OPENAI_BASE_URL=http://***.***.***.***:3000/v1
25+
MEMSCHEDULER_OPENAI_DEFAULT_MODEL=gpt-4o-mini
26+
27+
# Graph DB Configuration
28+
MEMSCHEDULER_GRAPHDBAUTH_URI=bolt://localhost:7687
29+
MEMSCHEDULER_GRAPHDBAUTH_USER=neo4j
30+
MEMSCHEDULER_GRAPHDBAUTH_PASSWORD=***
31+
MEMSCHEDULER_GRAPHDBAUTH_DB_NAME=neo4j
32+
MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true

evaluation/scripts/temporal_locomo/modules/base_eval_module.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -91,51 +91,53 @@ def __init__(self, args):
9191
self.ingestion_storage_dir = self.result_dir / "storages"
9292
self.mos_config_path = Path(f"{BASE_DIR}/configs-example/mos_w_scheduler_config.json")
9393
self.mem_cube_config_path = Path(f"{BASE_DIR}/configs-example/mem_cube_config.json")
94+
9495
self.openai_api_key = os.getenv("CHAT_MODEL_API_KEY")
9596
self.openai_base_url = os.getenv("CHAT_MODEL_BASE_URL")
9697
self.openai_chat_model = os.getenv("CHAT_MODEL")
9798

9899
auth_config_path = Path(f"{BASE_DIR}/scripts/temporal_locomo/eval_auth.json")
99100
if auth_config_path.exists():
100101
auth_config = AuthConfig.from_local_config(config_path=auth_config_path)
101-
102-
self.openai_api_key = auth_config.openai.api_key
103-
self.openai_base_url = auth_config.openai.base_url
104-
self.openai_chat_model = auth_config.openai.default_model
105-
106-
self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8"))
107-
self.mem_cube_config_data = json.load(
108-
self.mem_cube_config_path.open("r", encoding="utf-8")
109-
)
110-
111-
# Update LLM authentication information in MOS configuration using dictionary assignment
112-
self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_key"] = (
113-
auth_config.openai.api_key
114-
)
115-
self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_base"] = (
116-
auth_config.openai.base_url
117-
)
118-
119-
# Update graph database authentication information in memory cube configuration using dictionary assignment
120-
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["uri"] = (
121-
auth_config.graph_db.uri
122-
)
123-
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["user"] = (
124-
auth_config.graph_db.user
125-
)
126-
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["password"] = (
127-
auth_config.graph_db.password
128-
)
129-
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = (
130-
auth_config.graph_db.db_name
102+
print(
103+
f"✅ Configuration loaded successfully: from local config file {auth_config_path}"
131104
)
132-
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["auto_create"] = (
133-
auth_config.graph_db.auto_create
134-
)
135-
136105
else:
137-
print("Please referring to configs-example to provide valid configs.")
138-
exit()
106+
# Load .env file first before reading environment variables
107+
load_dotenv()
108+
auth_config = AuthConfig.from_local_env()
109+
print("✅ Configuration loaded successfully: from environment variables")
110+
self.openai_api_key = auth_config.openai.api_key
111+
self.openai_base_url = auth_config.openai.base_url
112+
self.openai_chat_model = auth_config.openai.default_model
113+
114+
self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8"))
115+
self.mem_cube_config_data = json.load(self.mem_cube_config_path.open("r", encoding="utf-8"))
116+
117+
# Update LLM authentication information in MOS configuration using dictionary assignment
118+
self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_key"] = (
119+
auth_config.openai.api_key
120+
)
121+
self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_base"] = (
122+
auth_config.openai.base_url
123+
)
124+
125+
# Update graph database authentication information in memory cube configuration using dictionary assignment
126+
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["uri"] = (
127+
auth_config.graph_db.uri
128+
)
129+
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["user"] = (
130+
auth_config.graph_db.user
131+
)
132+
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["password"] = (
133+
auth_config.graph_db.password
134+
)
135+
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = (
136+
auth_config.graph_db.db_name
137+
)
138+
self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["auto_create"] = (
139+
auth_config.graph_db.auto_create
140+
)
139141

140142
# Logger initialization
141143
self.logger = logger
@@ -158,7 +160,6 @@ def __init__(self, args):
158160

159161
self.can_answer_cases: list[RecordingCase] = []
160162
self.cannot_answer_cases: list[RecordingCase] = []
161-
load_dotenv()
162163

163164
def print_eval_info(self):
164165
"""

src/memos/configs/mem_scheduler.py

Lines changed: 102 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23

34
from pathlib import Path
@@ -135,7 +136,7 @@ class GraphDBAuthConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
135136
password: str = Field(
136137
default="",
137138
description="Password for graph database authentication",
138-
min_length=8, # 建议密码最小长度
139+
min_length=8, # Recommended minimum password length
139140
)
140141
db_name: str = Field(default="neo4j", description="Database name to connect to")
141142
auto_create: bool = Field(
@@ -150,13 +151,51 @@ class OpenAIConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
150151

151152

152153
class AuthConfig(BaseConfig, DictConversionMixin):
153-
rabbitmq: RabbitMQConfig
154-
openai: OpenAIConfig
155-
graph_db: GraphDBAuthConfig
154+
rabbitmq: RabbitMQConfig | None = None
155+
openai: OpenAIConfig | None = None
156+
graph_db: GraphDBAuthConfig | None = None
156157
default_config_path: ClassVar[str] = (
157158
f"{BASE_DIR}/examples/data/config/mem_scheduler/scheduler_auth.yaml"
158159
)
159160

161+
@model_validator(mode="after")
162+
def validate_partial_initialization(self) -> "AuthConfig":
163+
"""
164+
Validate that at least one configuration component is successfully initialized.
165+
Log warnings for any failed initializations but allow partial success.
166+
"""
167+
logger = logging.getLogger(__name__)
168+
169+
initialized_components = []
170+
failed_components = []
171+
172+
if self.rabbitmq is not None:
173+
initialized_components.append("rabbitmq")
174+
else:
175+
failed_components.append("rabbitmq")
176+
177+
if self.openai is not None:
178+
initialized_components.append("openai")
179+
else:
180+
failed_components.append("openai")
181+
182+
if self.graph_db is not None:
183+
initialized_components.append("graph_db")
184+
else:
185+
failed_components.append("graph_db")
186+
187+
# Allow all components to be None for flexibility, but log a warning
188+
if not initialized_components:
189+
logger.warning(
190+
"All configuration components are None. This may indicate missing environment variables or configuration files."
191+
)
192+
elif failed_components:
193+
logger.warning(
194+
f"Failed to initialize components: {', '.join(failed_components)}. Successfully initialized: {', '.join(initialized_components)}"
195+
)
196+
197+
return self
198+
160199
@classmethod
161200
def from_local_config(cls, config_path: str | Path | None = None) -> "AuthConfig":
162201
"""
@@ -205,24 +244,75 @@ def from_local_env(cls) -> "AuthConfig":
205244
206245
This method loads configuration for all nested components (RabbitMQ, OpenAI, GraphDB)
207246
from their respective environment variables using each component's specific prefix.
247+
If any component fails to initialize, it will be set to None and a warning will be logged.
208248
209249
Returns:
210250
AuthConfig: Configured instance with values from environment variables
211251
212252
Raises:
213-
ValueError: If any required environment variables are missing
253+
ValueError: If all components fail to initialize
214254
"""
255+
logger = logging.getLogger(__name__)
256+
257+
rabbitmq_config = None
258+
openai_config = None
259+
graph_db_config = None
260+
261+
# Try to initialize RabbitMQ config - check if any RabbitMQ env vars exist
262+
try:
263+
rabbitmq_prefix = RabbitMQConfig.get_env_prefix()
264+
has_rabbitmq_env = any(key.startswith(rabbitmq_prefix) for key in os.environ)
265+
if has_rabbitmq_env:
266+
rabbitmq_config = RabbitMQConfig.from_env()
267+
logger.info("Successfully initialized RabbitMQ configuration")
268+
else:
269+
logger.info(
270+
"No RabbitMQ environment variables found, skipping RabbitMQ initialization"
271+
)
272+
except (ValueError, Exception) as e:
273+
logger.warning(f"Failed to initialize RabbitMQ config from environment: {e}")
274+
275+
# Try to initialize OpenAI config - check if any OpenAI env vars exist
276+
try:
277+
openai_prefix = OpenAIConfig.get_env_prefix()
278+
has_openai_env = any(key.startswith(openai_prefix) for key in os.environ)
279+
if has_openai_env:
280+
openai_config = OpenAIConfig.from_env()
281+
logger.info("Successfully initialized OpenAI configuration")
282+
else:
283+
logger.info("No OpenAI environment variables found, skipping OpenAI initialization")
284+
except (ValueError, Exception) as e:
285+
logger.warning(f"Failed to initialize OpenAI config from environment: {e}")
286+
287+
# Try to initialize GraphDB config - check if any GraphDB env vars exist
288+
try:
289+
graphdb_prefix = GraphDBAuthConfig.get_env_prefix()
290+
has_graphdb_env = any(key.startswith(graphdb_prefix) for key in os.environ)
291+
if has_graphdb_env:
292+
graph_db_config = GraphDBAuthConfig.from_env()
293+
logger.info("Successfully initialized GraphDB configuration")
294+
else:
295+
logger.info(
296+
"No GraphDB environment variables found, skipping GraphDB initialization"
297+
)
298+
except (ValueError, Exception) as e:
299+
logger.warning(f"Failed to initialize GraphDB config from environment: {e}")
300+
215301
return cls(
216-
rabbitmq=RabbitMQConfig.from_env(),
217-
openai=OpenAIConfig.from_env(),
218-
graph_db=GraphDBAuthConfig.from_env(),
302+
rabbitmq=rabbitmq_config,
303+
openai=openai_config,
304+
graph_db=graph_db_config,
219305
)
220306

221307
def set_openai_config_to_environment(self):
222-
# Set environment variables
223-
os.environ["OPENAI_API_KEY"] = self.openai.api_key
224-
os.environ["OPENAI_BASE_URL"] = self.openai.base_url
225-
os.environ["MODEL"] = self.openai.default_model
308+
# Set environment variables only if openai config is available
309+
if self.openai is not None:
310+
os.environ["OPENAI_API_KEY"] = self.openai.api_key
311+
os.environ["OPENAI_BASE_URL"] = self.openai.base_url
312+
os.environ["MODEL"] = self.openai.default_model
313+
else:
314+
logger = logging.getLogger(__name__)
315+
logger.warning("OpenAI config is not available, skipping environment variable setup")
226316

227317
@classmethod
228318
def default_config_exists(cls) -> bool:

src/memos/mem_os/core.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,6 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler:
124124
f"Memory reader of type {type(self.mem_reader).__name__} "
125125
"missing required 'llm' attribute"
126126
)
127-
self._mem_scheduler.initialize_modules(
128-
chat_llm=self.chat_llm,
129-
process_llm=self.chat_llm,
130-
db_engine=self.user_manager.engine,
131-
)
132127
else:
133128
# Configure scheduler general_modules
134129
self._mem_scheduler.initialize_modules(

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def initialize_modules(
138138

139139
if self.auth_config is not None:
140140
self.rabbitmq_config = self.auth_config.rabbitmq
141-
self.initialize_rabbitmq(config=self.rabbitmq_config)
141+
if self.rabbitmq_config is not None:
142+
self.initialize_rabbitmq(config=self.rabbitmq_config)
142143

143144
logger.debug("GeneralScheduler has been initialized")
144145
except Exception as e:
@@ -497,6 +498,9 @@ def _submit_web_logs(
497498
Args:
498499
messages: Single log message or list of log messages
499500
"""
501+
if self.rabbitmq_config is None:
502+
return
503+
500504
if isinstance(messages, ScheduleLogForWebItem):
501505
messages = [messages] # transform single message to list
502506

@@ -526,7 +530,7 @@ def get_web_log_messages(self) -> list[dict]:
526530
messages = []
527531
while True:
528532
try:
529-
item = self._web_log_message_queue.get_nowait() # 线程安全的 get
533+
item = self._web_log_message_queue.get_nowait() # Thread-safe get
530534
messages.append(item.to_dict())
531535
except queue.Empty:
532536
break

src/memos/mem_scheduler/general_modules/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def join(self, timeout: float | None = None) -> bool:
206206
bool: True if all tasks completed, False if timeout occurred.
207207
"""
208208
if not self.enable_parallel_dispatch or self.dispatcher_executor is None:
209-
return True # 串行模式无需等待
209+
return True # Serial mode requires no waiting
210210

211211
done, not_done = concurrent.futures.wait(
212212
self._futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED

src/memos/mem_scheduler/general_modules/misc.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from queue import Empty, Full, Queue
77
from typing import TYPE_CHECKING, Any, Generic, TypeVar
88

9+
from dotenv import load_dotenv
910
from pydantic import field_serializer
1011

1112

@@ -32,7 +33,7 @@ def get_env_prefix(cls) -> str:
3233
Examples:
3334
RabbitMQConfig -> "RABBITMQ_"
3435
OpenAIConfig -> "OPENAI_"
35-
GraphDBAuthConfig -> "GRAPH_DB_AUTH_"
36+
GraphDBAuthConfig -> "GRAPHDBAUTH_"
3637
"""
3738
class_name = cls.__name__
3839
# Remove 'Config' suffix if present
@@ -55,6 +56,8 @@ def from_env(cls: type[T]) -> T:
5556
Raises:
5657
ValueError: If required environment variables are missing.
5758
"""
59+
load_dotenv()
60+
5861
prefix = cls.get_env_prefix()
5962
field_values = {}
6063

@@ -85,6 +88,35 @@ def _parse_env_value(cls, value: str, target_type: type) -> Any:
8588
return float(value)
8689
return value
8790

91+
@classmethod
92+
def print_env_mapping(cls) -> None:
93+
"""Print the mapping between class fields and their corresponding environment variable names.
94+
95+
Displays each field's name, type, whether it's required, default value, and corresponding environment variable name.
96+
"""
97+
prefix = cls.get_env_prefix()
98+
print(f"\n=== {cls.__name__} Environment Variable Mapping ===")
99+
print(f"Environment Variable Prefix: {prefix}")
100+
print("-" * 60)
101+
102+
if not hasattr(cls, "model_fields"):
103+
print("This class does not define model_fields, may not be a Pydantic model")
104+
return
105+
106+
for field_name, field_info in cls.model_fields.items():
107+
env_var = f"{prefix}{field_name.upper()}"
108+
field_type = field_info.annotation
109+
is_required = field_info.is_required()
110+
default_value = field_info.default if field_info.default is not None else "None"
111+
112+
print(f"Field Name: {field_name}")
113+
print(f" Environment Variable: {env_var}")
114+
print(f" Type: {field_type}")
115+
print(f" Required: {'Yes' if is_required else 'No'}")
116+
print(f" Default Value: {default_value}")
117+
print(f" Current Environment Value: {os.environ.get(env_var, 'Not Set')}")
118+
print("-" * 40)
119+
88120

89121
class DictConversionMixin:
90122
"""

src/memos/mem_scheduler/general_modules/scheduler_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def create_autofilled_log_item(
9898
)
9999
return log_message
100100

101-
# TODO: 日志打出来数量不对
101+
# TODO: Log output count is incorrect
102102
@log_exceptions(logger=logger)
103103
def log_working_memory_replacement(
104104
self,

0 commit comments

Comments
 (0)