Skip to content

Commit beb0e07

Browse files
fridayL and CaralHsi authored
feat: add singleton (#321)
* feat: add logs for cube * feat: add loggers for mem * add dict timer log * fix: change ci code * update size * fix ci * feat: add update threading dict * fix:ci code * fix:fix mem dumps for cube * feat: add --------- Co-authored-by: CaralHsi <[email protected]>
1 parent 1dc3b2e commit beb0e07

File tree

7 files changed

+188
-0
lines changed

7 files changed

+188
-0
lines changed

src/memos/embedders/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from memos.embedders.ollama import OllamaEmbedder
77
from memos.embedders.sentence_transformer import SenTranEmbedder
88
from memos.embedders.universal_api import UniversalAPIEmbedder
9+
from memos.memos_tools.singleton import singleton_factory
910

1011

1112
class EmbedderFactory(BaseEmbedder):
@@ -19,6 +20,7 @@ class EmbedderFactory(BaseEmbedder):
1920
}
2021

2122
@classmethod
23+
@singleton_factory()
2224
def from_config(cls, config_factory: EmbedderConfigFactory) -> BaseEmbedder:
2325
backend = config_factory.backend
2426
if backend not in cls.backend_to_class:

src/memos/llms/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from memos.llms.openai import AzureLLM, OpenAILLM
1010
from memos.llms.qwen import QwenLLM
1111
from memos.llms.vllm import VLLMLLM
12+
from memos.memos_tools.singleton import singleton_factory
1213

1314

1415
class LLMFactory(BaseLLM):
@@ -26,6 +27,7 @@ class LLMFactory(BaseLLM):
2627
}
2728

2829
@classmethod
30+
@singleton_factory()
2931
def from_config(cls, config_factory: LLMConfigFactory) -> BaseLLM:
3032
backend = config_factory.backend
3133
if backend not in cls.backend_to_class:

src/memos/mem_reader/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from memos.configs.mem_reader import MemReaderConfigFactory
44
from memos.mem_reader.base import BaseMemReader
55
from memos.mem_reader.simple_struct import SimpleStructMemReader
6+
from memos.memos_tools.singleton import singleton_factory
67

78

89
class MemReaderFactory(BaseMemReader):
@@ -13,6 +14,7 @@ class MemReaderFactory(BaseMemReader):
1314
}
1415

1516
@classmethod
17+
@singleton_factory()
1618
def from_config(cls, config_factory: MemReaderConfigFactory) -> BaseMemReader:
1719
backend = config_factory.backend
1820
if backend not in cls.backend_to_class:

src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
InternetGoogleRetriever,
1111
)
1212
from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever
13+
from memos.memos_tools.singleton import singleton_factory
1314

1415

1516
class InternetRetrieverFactory:
@@ -23,6 +24,7 @@ class InternetRetrieverFactory:
2324
}
2425

2526
@classmethod
27+
@singleton_factory()
2628
def from_config(
2729
cls, config_factory: InternetRetrieverConfigFactory, embedder: BaseEmbedder
2830
) -> InternetGoogleRetriever | None:

src/memos/memos_tools/singleton.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
Singleton decorator module for caching factory instances to avoid excessive memory usage
3+
from repeated initialization.
4+
"""
5+
6+
import hashlib
7+
import json
8+
9+
from collections.abc import Callable
10+
from functools import wraps
11+
from typing import Any, TypeVar
12+
from weakref import WeakValueDictionary
13+
14+
15+
# Generic return type of the factory callable wrapped by the singleton decorator.
T = TypeVar("T")
16+
17+
18+
class FactorySingleton:
19+
"""Factory singleton manager that caches instances based on configuration parameters"""
20+
21+
def __init__(self):
22+
# Use weak reference dictionary for automatic cleanup when instances are no longer referenced
23+
self._instances: dict[str, WeakValueDictionary] = {}
24+
25+
def _generate_cache_key(self, config: Any, *args, **kwargs) -> str:
26+
"""Generate cache key based on configuration only (ignoring other parameters)"""
27+
28+
# Handle configuration objects - only use the config parameter
29+
if hasattr(config, "model_dump"): # Pydantic model
30+
config_data = config.model_dump()
31+
elif hasattr(config, "dict"): # Legacy Pydantic model
32+
config_data = config.dict()
33+
elif isinstance(config, dict):
34+
config_data = config
35+
else:
36+
# For other types, try to convert to string
37+
config_data = str(config)
38+
39+
# Filter out time-related fields that shouldn't affect caching
40+
filtered_config = self._filter_temporal_fields(config_data)
41+
42+
# Generate hash key based only on config
43+
try:
44+
cache_str = json.dumps(filtered_config, sort_keys=True, ensure_ascii=False, default=str)
45+
except (TypeError, ValueError):
46+
# If JSON serialization fails, convert the entire config to string
47+
cache_str = str(filtered_config)
48+
49+
return hashlib.md5(cache_str.encode("utf-8")).hexdigest()
50+
51+
def _filter_temporal_fields(self, config_data: Any) -> Any:
52+
"""Filter out temporal fields that shouldn't affect instance caching"""
53+
if isinstance(config_data, dict):
54+
filtered = {}
55+
for key, value in config_data.items():
56+
# Skip common temporal field names
57+
if key.lower() in {
58+
"created_at",
59+
"updated_at",
60+
"timestamp",
61+
"time",
62+
"date",
63+
"created_time",
64+
"updated_time",
65+
"last_modified",
66+
"modified_at",
67+
"start_time",
68+
"end_time",
69+
"execution_time",
70+
"run_time",
71+
}:
72+
continue
73+
# Recursively filter nested dictionaries
74+
filtered[key] = self._filter_temporal_fields(value)
75+
return filtered
76+
elif isinstance(config_data, list):
77+
# Recursively filter lists
78+
return [self._filter_temporal_fields(item) for item in config_data]
79+
else:
80+
# For primitive types, return as-is
81+
return config_data
82+
83+
def get_or_create(self, factory_class: type, cache_key: str, creator_func: Callable) -> Any:
84+
"""Get or create instance"""
85+
class_name = factory_class.__name__
86+
87+
if class_name not in self._instances:
88+
self._instances[class_name] = WeakValueDictionary()
89+
90+
class_cache = self._instances[class_name]
91+
92+
if cache_key in class_cache:
93+
return class_cache[cache_key]
94+
95+
# Create new instance
96+
instance = creator_func()
97+
class_cache[cache_key] = instance
98+
return instance
99+
100+
def clear_cache(self, factory_class: type | None = None):
101+
"""Clear cache"""
102+
if factory_class:
103+
class_name = factory_class.__name__
104+
if class_name in self._instances:
105+
self._instances[class_name].clear()
106+
else:
107+
for cache in self._instances.values():
108+
cache.clear()
109+
110+
111+
# Global singleton manager
112+
# Shared manager instance that the `singleton_factory` decorator reads and writes.
_factory_singleton = FactorySingleton()
113+
114+
115+
def singleton_factory(factory_class: type | str | None = None):
    """
    Factory singleton decorator

    Caches the object returned by the decorated factory function, keyed by the
    configuration argument only; any further arguments are passed through on
    creation but do not take part in the cache key.

    Usage:
        @singleton_factory()
        def from_config(cls, config):
            return SomeClass(config)

    Or specify factory class:
        @singleton_factory(EmbedderFactory)
        def from_config(cls, config):
            return SomeClass(config)

    Args:
        factory_class: Class, or class name as a string, that namespaces the
            cache. Optional for classmethods (``cls`` is used); required for
            static methods.

    Raises:
        ValueError: At call time, when the first positional argument is not a
            class and no ``factory_class`` was given.
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Resolve a string name once, at decoration time. ``type(name, (), {})``
        # yields a class whose ``__name__`` really is ``name``; assigning
        # ``__name__`` inside a class body does not change what
        # ``TheClass.__name__`` returns, so that approach would put every
        # string-named factory into the same cache namespace.
        if isinstance(factory_class, str):
            declared_factory_class = type(factory_class, (), {})
        else:
            declared_factory_class = factory_class

        @wraps(func)
        def wrapper(*args, **kwargs) -> T:
            # Without positional arguments there is no config to key on (a
            # config passed by keyword is not inspected), so skip the cache.
            if not args:
                return func(*args, **kwargs)

            first = args[0]
            if isinstance(first, type):
                # First parameter is a class (cls), so this is a @classmethod
                target_factory_class = (
                    declared_factory_class if declared_factory_class is not None else first
                )
                config = args[1] if len(args) > 1 else None
            else:
                # First parameter is config, so this is a @staticmethod
                if declared_factory_class is None:
                    raise ValueError(
                        "Factory class must be explicitly specified for static methods"
                    )
                target_factory_class = declared_factory_class
                config = first

            if config is None:
                # If no configuration parameter, call original function directly
                return func(*args, **kwargs)

            # Generate cache key based only on config
            cache_key = _factory_singleton._generate_cache_key(config)

            # Function to create instance
            def creator():
                return func(*args, **kwargs)

            # Get or create instance
            return _factory_singleton.get_or_create(target_factory_class, cache_key, creator)

        return wrapper

    return decorator

src/memos/parsers/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, ClassVar
22

33
from memos.configs.parser import ParserConfigFactory
4+
from memos.memos_tools.singleton import singleton_factory
45
from memos.parsers.base import BaseParser
56
from memos.parsers.markitdown import MarkItDownParser
67

@@ -11,6 +12,7 @@ class ParserFactory(BaseParser):
1112
backend_to_class: ClassVar[dict[str, Any]] = {"markitdown": MarkItDownParser}
1213

1314
@classmethod
15+
@singleton_factory()
1416
def from_config(cls, config_factory: ParserConfigFactory) -> BaseParser:
1517
backend = config_factory.backend
1618
if backend not in cls.backend_to_class:

src/memos/reranker/factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from typing import TYPE_CHECKING, Any
55

6+
# Import singleton decorator
7+
from memos.memos_tools.singleton import singleton_factory
8+
69
from .cosine_local import CosineLocalReranker
710
from .http_bge import HTTPBGEReranker
811
from .noop import NoopReranker
@@ -16,6 +19,7 @@
1619

1720
class RerankerFactory:
1821
@staticmethod
22+
@singleton_factory("RerankerFactory")
1923
def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
2024
if not cfg:
2125
return None

0 commit comments

Comments
 (0)