Skip to content

Commit 2daa6f2

Browse files
authored
Merge branch 'test' into feat/nebula_update
2 parents a020369 + beb0e07 commit 2daa6f2

File tree

10 files changed

+594
-12
lines changed

10 files changed

+594
-12
lines changed

src/memos/embedders/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from memos.embedders.ollama import OllamaEmbedder
77
from memos.embedders.sentence_transformer import SenTranEmbedder
88
from memos.embedders.universal_api import UniversalAPIEmbedder
9+
from memos.memos_tools.singleton import singleton_factory
910

1011

1112
class EmbedderFactory(BaseEmbedder):
@@ -19,6 +20,7 @@ class EmbedderFactory(BaseEmbedder):
1920
}
2021

2122
@classmethod
23+
@singleton_factory()
2224
def from_config(cls, config_factory: EmbedderConfigFactory) -> BaseEmbedder:
2325
backend = config_factory.backend
2426
if backend not in cls.backend_to_class:

src/memos/llms/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from memos.llms.openai import AzureLLM, OpenAILLM
1010
from memos.llms.qwen import QwenLLM
1111
from memos.llms.vllm import VLLMLLM
12+
from memos.memos_tools.singleton import singleton_factory
1213

1314

1415
class LLMFactory(BaseLLM):
@@ -26,6 +27,7 @@ class LLMFactory(BaseLLM):
2627
}
2728

2829
@classmethod
30+
@singleton_factory()
2931
def from_config(cls, config_factory: LLMConfigFactory) -> BaseLLM:
3032
backend = config_factory.backend
3133
if backend not in cls.backend_to_class:

src/memos/mem_os/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from memos.memories.activation.item import ActivationMemoryItem
2525
from memos.memories.parametric.item import ParametricMemoryItem
2626
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
27-
from memos.memos_tools.thread_safe_dict import ThreadSafeDict
27+
from memos.memos_tools.thread_safe_dict_segment import OptimizedThreadSafeDict
2828
from memos.templates.mos_prompts import QUERY_REWRITING_PROMPT
2929
from memos.types import ChatHistory, MessageList, MOSSearchResult
3030

@@ -47,8 +47,8 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None):
4747
self.mem_reader = MemReaderFactory.from_config(config.mem_reader)
4848
self.chat_history_manager: dict[str, ChatHistory] = {}
4949
# use thread safe dict for multi-user product-server scenario
50-
self.mem_cubes: ThreadSafeDict[str, GeneralMemCube] = (
51-
ThreadSafeDict() if user_manager is not None else {}
50+
self.mem_cubes: OptimizedThreadSafeDict[str, GeneralMemCube] = (
51+
OptimizedThreadSafeDict() if user_manager is not None else {}
5252
)
5353
self._register_chat_history()
5454

src/memos/mem_os/product.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,10 @@ def register_mem_cube(
796796
logger.info(
797797
f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}"
798798
)
799+
time_start = time.time()
799800
self.mem_cubes[mem_cube_id] = mem_cube
801+
time_end = time.time()
802+
logger.info(f"time register_mem_cube: add mem_cube time is: {time_end - time_start}")
800803

801804
def user_register(
802805
self,
@@ -847,13 +850,14 @@ def user_register(
847850
cube_path=mem_cube_name_or_path,
848851
cube_id=mem_cube_id,
849852
)
850-
853+
time_start = time.time()
851854
if default_mem_cube:
852855
try:
853-
default_mem_cube.dump(mem_cube_name_or_path)
856+
default_mem_cube.dump(mem_cube_name_or_path, memory_types=[])
854857
except Exception as e:
855858
logger.error(f"Failed to dump default cube: {e}")
856-
859+
time_end = time.time()
860+
logger.info(f"time user_register: dump default cube time is: {time_end - time_start}")
857861
# Register the default cube with MOS
858862
self.register_mem_cube(
859863
mem_cube_name_or_path_or_object=default_mem_cube,
@@ -1316,9 +1320,14 @@ def search(
13161320
# Load user cubes if not already loaded
13171321
time_start = time.time()
13181322
self._load_user_cubes(user_id, self.default_cube_config)
1319-
dict_size = sys.getsizeof(self.mem_cubes._dict)
1320-
size_mb = dict_size / (1024 * 1024)
1321-
logger.info(f"now search memcubes_size is : {len(self.mem_cubes)} {size_mb}MB")
1323+
try:
1324+
dict_size = sys.getsizeof(self.mem_cubes)
1325+
size_mb = dict_size / (1024 * 1024)
1326+
logger.info(
1327+
f"now search memcubes_size is : len is {len(self.mem_cubes)} and {size_mb}MB"
1328+
)
1329+
except Exception as e:
1330+
logger.warning(f"Failed to get memcubes size: {e}, ignore it")
13221331
load_user_cubes_time_end = time.time()
13231332
logger.info(
13241333
f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}"
@@ -1367,9 +1376,12 @@ def add(
13671376

13681377
# Load user cubes if not already loaded
13691378
self._load_user_cubes(user_id, self.default_cube_config)
1370-
dict_size = sys.getsizeof(self.mem_cubes._dict)
1371-
size_mb = dict_size / (1024 * 1024)
1372-
logger.info(f"now add memcubes_size is : {len(self.mem_cubes)} {size_mb}MB")
1379+
try:
1380+
dict_size = sys.getsizeof(self.mem_cubes)
1381+
size_mb = dict_size / (1024 * 1024)
1382+
logger.info(f"now add memcubes_size is : {len is len(self.mem_cubes)} and {size_mb}MB")
1383+
except Exception as e:
1384+
logger.warning(f"Failed to get memcubes size: {e}, ignore it")
13731385
result = super().add(
13741386
messages, memory_content, doc_path, mem_cube_id, user_id, session_id=session_id
13751387
)

src/memos/mem_reader/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from memos.configs.mem_reader import MemReaderConfigFactory
44
from memos.mem_reader.base import BaseMemReader
55
from memos.mem_reader.simple_struct import SimpleStructMemReader
6+
from memos.memos_tools.singleton import singleton_factory
67

78

89
class MemReaderFactory(BaseMemReader):
@@ -13,6 +14,7 @@ class MemReaderFactory(BaseMemReader):
1314
}
1415

1516
@classmethod
17+
@singleton_factory()
1618
def from_config(cls, config_factory: MemReaderConfigFactory) -> BaseMemReader:
1719
backend = config_factory.backend
1820
if backend not in cls.backend_to_class:

src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
InternetGoogleRetriever,
1111
)
1212
from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever
13+
from memos.memos_tools.singleton import singleton_factory
1314

1415

1516
class InternetRetrieverFactory:
@@ -23,6 +24,7 @@ class InternetRetrieverFactory:
2324
}
2425

2526
@classmethod
27+
@singleton_factory()
2628
def from_config(
2729
cls, config_factory: InternetRetrieverConfigFactory, embedder: BaseEmbedder
2830
) -> InternetGoogleRetriever | None:

src/memos/memos_tools/singleton.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
Singleton decorator module for caching factory instances to avoid excessive memory usage
3+
from repeated initialization.
4+
"""
5+
6+
import hashlib
7+
import json
8+
9+
from collections.abc import Callable
10+
from functools import wraps
11+
from typing import Any, TypeVar
12+
from weakref import WeakValueDictionary
13+
14+
15+
T = TypeVar("T")
16+
17+
18+
class FactorySingleton:
19+
"""Factory singleton manager that caches instances based on configuration parameters"""
20+
21+
def __init__(self):
22+
# Use weak reference dictionary for automatic cleanup when instances are no longer referenced
23+
self._instances: dict[str, WeakValueDictionary] = {}
24+
25+
def _generate_cache_key(self, config: Any, *args, **kwargs) -> str:
26+
"""Generate cache key based on configuration only (ignoring other parameters)"""
27+
28+
# Handle configuration objects - only use the config parameter
29+
if hasattr(config, "model_dump"): # Pydantic model
30+
config_data = config.model_dump()
31+
elif hasattr(config, "dict"): # Legacy Pydantic model
32+
config_data = config.dict()
33+
elif isinstance(config, dict):
34+
config_data = config
35+
else:
36+
# For other types, try to convert to string
37+
config_data = str(config)
38+
39+
# Filter out time-related fields that shouldn't affect caching
40+
filtered_config = self._filter_temporal_fields(config_data)
41+
42+
# Generate hash key based only on config
43+
try:
44+
cache_str = json.dumps(filtered_config, sort_keys=True, ensure_ascii=False, default=str)
45+
except (TypeError, ValueError):
46+
# If JSON serialization fails, convert the entire config to string
47+
cache_str = str(filtered_config)
48+
49+
return hashlib.md5(cache_str.encode("utf-8")).hexdigest()
50+
51+
def _filter_temporal_fields(self, config_data: Any) -> Any:
52+
"""Filter out temporal fields that shouldn't affect instance caching"""
53+
if isinstance(config_data, dict):
54+
filtered = {}
55+
for key, value in config_data.items():
56+
# Skip common temporal field names
57+
if key.lower() in {
58+
"created_at",
59+
"updated_at",
60+
"timestamp",
61+
"time",
62+
"date",
63+
"created_time",
64+
"updated_time",
65+
"last_modified",
66+
"modified_at",
67+
"start_time",
68+
"end_time",
69+
"execution_time",
70+
"run_time",
71+
}:
72+
continue
73+
# Recursively filter nested dictionaries
74+
filtered[key] = self._filter_temporal_fields(value)
75+
return filtered
76+
elif isinstance(config_data, list):
77+
# Recursively filter lists
78+
return [self._filter_temporal_fields(item) for item in config_data]
79+
else:
80+
# For primitive types, return as-is
81+
return config_data
82+
83+
def get_or_create(self, factory_class: type, cache_key: str, creator_func: Callable) -> Any:
84+
"""Get or create instance"""
85+
class_name = factory_class.__name__
86+
87+
if class_name not in self._instances:
88+
self._instances[class_name] = WeakValueDictionary()
89+
90+
class_cache = self._instances[class_name]
91+
92+
if cache_key in class_cache:
93+
return class_cache[cache_key]
94+
95+
# Create new instance
96+
instance = creator_func()
97+
class_cache[cache_key] = instance
98+
return instance
99+
100+
def clear_cache(self, factory_class: type | None = None):
101+
"""Clear cache"""
102+
if factory_class:
103+
class_name = factory_class.__name__
104+
if class_name in self._instances:
105+
self._instances[class_name].clear()
106+
else:
107+
for cache in self._instances.values():
108+
cache.clear()
109+
110+
111+
# Global singleton manager
112+
_factory_singleton = FactorySingleton()
113+
114+
115+
def singleton_factory(factory_class: type | str | None = None):
116+
"""
117+
Factory singleton decorator
118+
119+
Usage:
120+
@singleton_factory()
121+
def from_config(cls, config):
122+
return SomeClass(config)
123+
124+
Or specify factory class:
125+
@singleton_factory(EmbedderFactory)
126+
def from_config(cls, config):
127+
return SomeClass(config)
128+
"""
129+
130+
def decorator(func: Callable[..., T]) -> Callable[..., T]:
131+
@wraps(func)
132+
def wrapper(*args, **kwargs) -> T:
133+
# Determine factory class and config parameter
134+
target_factory_class = factory_class
135+
config = None
136+
137+
# Simple logic: check if first parameter is a class or config
138+
if args:
139+
if hasattr(args[0], "__name__") and hasattr(args[0], "__module__"):
140+
# First parameter is a class (cls), so this is a @classmethod
141+
if target_factory_class is None:
142+
target_factory_class = args[0]
143+
config = args[1] if len(args) > 1 else None
144+
else:
145+
# First parameter is config, so this is a @staticmethod
146+
if target_factory_class is None:
147+
raise ValueError(
148+
"Factory class must be explicitly specified for static methods"
149+
)
150+
if isinstance(target_factory_class, str):
151+
# Convert string to a mock class for caching purposes
152+
class MockFactoryClass:
153+
__name__ = target_factory_class
154+
155+
target_factory_class = MockFactoryClass
156+
config = args[0]
157+
158+
if config is None:
159+
# If no configuration parameter, call original function directly
160+
return func(*args, **kwargs)
161+
162+
# Generate cache key based only on config
163+
cache_key = _factory_singleton._generate_cache_key(config)
164+
165+
# Function to create instance
166+
def creator():
167+
return func(*args, **kwargs)
168+
169+
# Get or create instance
170+
return _factory_singleton.get_or_create(target_factory_class, cache_key, creator)
171+
172+
return wrapper
173+
174+
return decorator

0 commit comments

Comments
 (0)