Skip to content

Commit 5390e2d

Browse files
committed
build: optimize dependency management
1 parent b4d5a88 commit 5390e2d

File tree

16 files changed

+573
-451
lines changed

16 files changed

+573
-451
lines changed

poetry.lock

Lines changed: 317 additions & 367 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,58 @@ keywords = ["memory", "llm", "language model", "memoryOS", "agent"]
1212
packages = [{include = "memos", from = "src"}]
1313

1414
[tool.poetry.dependencies]
15+
16+
# Core dependencies. These should not be optional.
1517
python = "^3.10"
1618
openai = "^1.77.0"
1719
ollama = "^0.4.8"
18-
qdrant-client = "^1.14.2"
20+
tenacity = "^9.1.2" # Error handling and retrying library
1921
transformers = "^4.51.3"
20-
markitdown = {extras = ["docx", "pdf", "pptx", "xls", "xlsx"], version = "^0.1.1"}
21-
chonkie = "^1.0.7"
22-
tenacity = "^9.1.2"
23-
neo4j = "^5.28.1"
24-
accelerate = "^1.7.0"
2522
fastapi = {extras = ["all"], version = "^0.115.12"}
26-
sentence-transformers = "^4.1.0"
27-
sqlalchemy = "^2.0.41"
28-
redis = "^6.2.0"
29-
pika = "^1.3.2"
30-
schedule = "^1.2.2"
31-
volcengine-python-sdk = "^4.0.4"
23+
sqlalchemy = "^2.0.41" # SQL toolkit
24+
25+
# GeneralTextualMemory dependencies
26+
qdrant-client = {version = "^1.14.2", optional = true} # Vector database
27+
28+
# TreeTextualMemory dependencies
29+
neo4j = {version = "^5.28.1", optional = true} # Graph database
30+
schedule = {version = "^1.2.2", optional = true} # Task scheduling library
31+
scikit-learn = {version = "^1.7.0", optional = true} # Machine learning library
32+
33+
# MemScheduler dependencies
34+
redis = {version = "^6.2.0", optional = true} # Key-value store
35+
pika = {version = "^1.3.2", optional = true} # RabbitMQ client for Python
36+
37+
[tool.poetry.extras]
38+
general-mem = ["qdrant-client"]
39+
tree-mem = ["neo4j", "schedule", "scikit-learn"]
40+
mem-scheduler = ["redis", "pika"]
41+
all = [
42+
"qdrant-client",
43+
"neo4j",
44+
"schedule",
45+
"scikit-learn",
46+
"redis",
47+
"pika",
48+
]
3249

3350
[tool.poetry.group.dev]
3451
optional = false
3552

3653
[tool.poetry.group.dev.dependencies]
54+
# Core development dependencies
3755
pre-commit = "^4.2.0"
3856
ruff = "^0.11.8"
3957

58+
# Hard-to-avoid dependencies
59+
torch = "^2.0.0" # CPU-version of PyTorch, only used for testing
60+
61+
# Infrequently used dependencies
62+
markitdown = {extras = ["docx", "pdf", "pptx", "xls", "xlsx"], version = "^0.1.1"} # MarkItDown parser for various file formats
63+
chonkie = "^1.0.7" # Sentence chunking library
64+
sentence-transformers = "^4.1.0" # Text Embedding
65+
volcengine-python-sdk = "^4.0.4" # ByteDance's AI Service SDK (namely, Volcano Ark)
66+
4067
[tool.poetry.group.test]
4168
optional = false
4269

src/memos/chunkers/sentence_chunker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from chonkie import SentenceChunker as ChonkieSentenceChunker
2-
31
from memos.configs.chunker import SentenceChunkerConfig
2+
from memos.dependency import require_python_package
43
from memos.log import get_logger
54

65
from .base import BaseChunker, Chunk
@@ -12,7 +11,14 @@
1211
class SentenceChunker(BaseChunker):
1312
"""Sentence-based text chunker."""
1413

14+
@require_python_package(
15+
import_name="chonkie",
16+
install_command="pip install chonkie",
17+
install_link="https://docs.chonkie.ai/python-sdk/getting-started/installation",
18+
)
1519
def __init__(self, config: SentenceChunkerConfig):
20+
from chonkie import SentenceChunker as ChonkieSentenceChunker
21+
1622
self.config = config
1723
self.chunker = ChonkieSentenceChunker(
1824
tokenizer_or_token_counter=config.tokenizer_or_token_counter,

src/memos/dependency.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
This utility provides tools for managing dependencies in MemOS.
3+
"""
4+
5+
import functools
6+
import importlib
7+
8+
9+
def require_python_package(
10+
import_name: str, install_command: str | None = None, install_link: str | None = None
11+
):
12+
"""Check if a package is available and provide installation hints on import failure.
13+
14+
Args:
15+
import_name (str): The top-level importable module name a package provides.
16+
install_command (str, optional): Installation command.
17+
install_link (str, optional): URL link to installation guide.
18+
19+
Returns:
20+
Callable: A decorator function that wraps the target function with package availability check.
21+
22+
Raises:
23+
ImportError: When the specified package is not available, with installation
24+
instructions included in the error message.
25+
26+
Example:
27+
>>> @require_python_package(
28+
... import_name='faiss',
29+
... install_command='pip install faiss-cpu',
30+
... install_link='https://github.com/facebookresearch/faiss/blob/main/INSTALL.md'
31+
... )
32+
... def create_faiss_index():
33+
... from faiss import IndexFlatL2 # Actual import in function
34+
... return IndexFlatL2(128)
35+
"""
36+
37+
def decorator(func):
38+
@functools.wraps(func)
39+
def wrapper(*args, **kwargs):
40+
try:
41+
importlib.import_module(import_name)
42+
except ImportError:
43+
error_msg = f"Missing required module - '{import_name}'\n"
44+
error_msg += f"💡 Install command: {install_command}\n" if install_command else ""
45+
error_msg += f"💡 Install guide: {install_link}\n" if install_link else ""
46+
47+
raise ImportError(error_msg) from None
48+
return func(*args, **kwargs)
49+
50+
return wrapper
51+
52+
return decorator

src/memos/embedders/ark.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
1-
from volcenginesdkarkruntime import Ark
2-
from volcenginesdkarkruntime.types.multimodal_embedding import (
3-
EmbeddingInputParam,
4-
MultimodalEmbeddingContentPartTextParam,
5-
MultimodalEmbeddingResponse,
6-
)
7-
81
from memos.configs.embedder import ArkEmbedderConfig
2+
from memos.dependency import require_python_package
93
from memos.embedders.base import BaseEmbedder
104
from memos.log import get_logger
115

@@ -16,7 +10,14 @@
1610
class ArkEmbedder(BaseEmbedder):
1711
"""Ark Embedder class."""
1812

13+
@require_python_package(
14+
import_name="volcenginesdkarkruntime",
15+
install_command="pip install 'volcengine-python-sdk[ark]'",
16+
install_link="https://www.volcengine.com/docs/82379/1541595",
17+
)
1918
def __init__(self, config: ArkEmbedderConfig):
19+
from volcenginesdkarkruntime import Ark
20+
2021
self.config = config
2122

2223
if self.config.embedding_dims is not None:
@@ -44,6 +45,10 @@ def embed(self, texts: list[str]) -> list[list[float]]:
4445
Returns:
4546
List of embeddings, each represented as a list of floats.
4647
"""
48+
from volcenginesdkarkruntime.types.multimodal_embedding import (
49+
MultimodalEmbeddingContentPartTextParam,
50+
)
51+
4752
if self.config.multi_modal:
4853
texts_input = [
4954
MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
@@ -66,8 +71,12 @@ def text_embedding(self, inputs: list[str], chunk_size: int | None = None) -> li
6671
return embeddings
6772

6873
def multimodal_embeddings(
69-
self, inputs: list[EmbeddingInputParam], chunk_size: int | None = None
74+
self, inputs: list, chunk_size: int | None = None
7075
) -> list[list[float]]:
76+
from volcenginesdkarkruntime.types.multimodal_embedding import (
77+
MultimodalEmbeddingResponse, # noqa: TC002
78+
)
79+
7180
chunk_size_ = chunk_size or self.config.chunk_size
7281
embeddings: list[list[float]] = []
7382

src/memos/embedders/sentence_transformer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from sentence_transformers import SentenceTransformer
2-
31
from memos.configs.embedder import SenTranEmbedderConfig
2+
from memos.dependency import require_python_package
43
from memos.embedders.base import BaseEmbedder
54
from memos.log import get_logger
65

@@ -11,7 +10,14 @@
1110
class SenTranEmbedder(BaseEmbedder):
1211
"""Sentence Transformer Embedder class."""
1312

13+
@require_python_package(
14+
import_name="sentence_transformers",
15+
install_command="pip install sentence-transformers",
16+
install_link="https://www.sbert.net/docs/installation.html",
17+
)
1418
def __init__(self, config: SenTranEmbedderConfig):
19+
from sentence_transformers import SentenceTransformer
20+
1521
self.config = config
1622
self.model = SentenceTransformer(
1723
self.config.model_name_or_path, trust_remote_code=self.config.trust_remote_code

src/memos/graph_dbs/neo4j.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
from datetime import datetime
44
from typing import Any, Literal
55

6-
from neo4j import GraphDatabase
7-
from neo4j.exceptions import ClientError
8-
96
from memos.configs.graph_db import Neo4jGraphDBConfig
7+
from memos.dependency import require_python_package
108
from memos.graph_dbs.base import BaseGraphDB
119
from memos.log import get_logger
1210

@@ -57,6 +55,11 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
5755
class Neo4jGraphDB(BaseGraphDB):
5856
"""Neo4j-based implementation of a graph memory store."""
5957

58+
@require_python_package(
59+
import_name="neo4j",
60+
install_command="pip install neo4j",
61+
install_link="https://neo4j.com/docs/python-manual/current/install/",
62+
)
6063
def __init__(self, config: Neo4jGraphDBConfig):
6164
"""Neo4j-based implementation of a graph memory store.
6265
@@ -75,6 +78,7 @@ def __init__(self, config: Neo4jGraphDBConfig):
7578
All node queries will enforce `user_name` in WHERE conditions and store it in metadata,
7679
but it will be removed automatically before returning to external consumers.
7780
"""
81+
from neo4j import GraphDatabase
7882

7983
self.config = config
8084
self.driver = GraphDatabase.driver(config.uri, auth=(config.user, config.password))
@@ -994,6 +998,8 @@ def drop_database(self) -> None:
994998
)
995999

9961000
def _ensure_database_exists(self):
1001+
from neo4j.exceptions import ClientError
1002+
9971003
try:
9981004
with self.driver.session(database="system") as session:
9991005
session.run(f"CREATE DATABASE `{self.db_name}` IF NOT EXISTS")

src/memos/llms/hf.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import Generator
2-
3-
import torch
2+
from typing import Any
43

54
from transformers import (
65
AutoModelForCausalLM,
@@ -134,6 +133,8 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
134133
Yields:
135134
str: Streaming response chunks.
136135
"""
136+
import torch
137+
137138
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
138139

139140
# Get generation parameters
@@ -200,6 +201,8 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
200201
Returns:
201202
str: Model response.
202203
"""
204+
import torch
205+
203206
query_ids = self.tokenizer(
204207
query, return_tensors="pt", add_special_tokens=False
205208
).input_ids.to(self.model.device)
@@ -287,10 +290,7 @@ def _generate_with_cache_stream(
287290

288291
generated.append(next_token)
289292

290-
@torch.no_grad()
291-
def _prefill(
292-
self, input_ids: torch.Tensor, kv: DynamicCache
293-
) -> tuple[torch.Tensor, DynamicCache]:
293+
def _prefill(self, input_ids: Any, kv: DynamicCache) -> tuple[Any, DynamicCache]:
294294
"""
295295
Forward the model once, returning last-step logits and updated KV cache.
296296
Args:
@@ -299,22 +299,27 @@ def _prefill(
299299
Returns:
300300
tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
301301
"""
302-
out = self.model(
303-
input_ids=input_ids,
304-
use_cache=True,
305-
past_key_values=kv,
306-
return_dict=True,
307-
)
302+
import torch
303+
304+
with torch.no_grad():
305+
out = self.model(
306+
input_ids=input_ids,
307+
use_cache=True,
308+
past_key_values=kv,
309+
return_dict=True,
310+
)
308311
return out.logits[:, -1, :], out.past_key_values
309312

310-
def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
313+
def _select_next_token(self, logits: Any) -> Any:
311314
"""
312315
Select the next token from logits using sampling or argmax, depending on config.
313316
Args:
314317
logits (torch.Tensor): Logits for the next token.
315318
Returns:
316319
torch.Tensor: Selected token ID(s).
317320
"""
321+
import torch
322+
318323
if getattr(self.config, "do_sample", True):
319324
batch_size, _ = logits.size()
320325
dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
@@ -323,7 +328,7 @@ def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
323328
return torch.multinomial(probs, num_samples=1)
324329
return torch.argmax(logits, dim=-1, keepdim=True)
325330

326-
def _should_stop(self, token: torch.Tensor) -> bool:
331+
def _should_stop(self, token: Any) -> bool:
327332
"""
328333
Check if the given token is the EOS (end-of-sequence) token.
329334
Args:
@@ -347,6 +352,8 @@ def build_kv_cache(self, messages) -> DynamicCache:
347352
Returns:
348353
DynamicCache: The constructed KV cache object.
349354
"""
355+
import torch
356+
350357
# Accept multiple input types and convert to standard chat messages
351358
if isinstance(messages, str):
352359
messages = [

0 commit comments

Comments
 (0)