Skip to content

Commit 7426969

Browse files
chore: apply ruff linting fixes and type annotations to memory module
Co-authored-by: Lorenze Jay <[email protected]>
1 parent d879be8 commit 7426969

File tree

9 files changed

+52
-46
lines changed

9 files changed

+52
-46
lines changed

src/crewai/memory/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from .entity.entity_memory import EntityMemory
2+
from .external.external_memory import ExternalMemory
23
from .long_term.long_term_memory import LongTermMemory
34
from .short_term.short_term_memory import ShortTermMemory
4-
from .external.external_memory import ExternalMemory
55

66
__all__ = [
77
"EntityMemory",
8+
"ExternalMemory",
89
"LongTermMemory",
910
"ShortTermMemory",
10-
"ExternalMemory",
1111
]

src/crewai/memory/external/external_memory_item.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any
22

33

44
class ExternalMemoryItem:
55
def __init__(
66
self,
77
value: Any,
8-
metadata: Optional[Dict[str, Any]] = None,
9-
agent: Optional[str] = None,
8+
metadata: dict[str, Any] | None = None,
9+
agent: str | None = None,
1010
):
1111
self.value = value
1212
self.metadata = metadata

src/crewai/memory/long_term/long_term_memory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from typing import Any, Dict, List
21
import time
2+
from typing import Any
33

4-
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
5-
from crewai.memory.memory import Memory
64
from crewai.events.event_bus import crewai_event_bus
75
from crewai.events.types.memory_events import (
8-
MemoryQueryStartedEvent,
96
MemoryQueryCompletedEvent,
107
MemoryQueryFailedEvent,
11-
MemorySaveStartedEvent,
8+
MemoryQueryStartedEvent,
129
MemorySaveCompletedEvent,
1310
MemorySaveFailedEvent,
11+
MemorySaveStartedEvent,
1412
)
13+
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
14+
from crewai.memory.memory import Memory
1515
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
1616

1717

@@ -84,7 +84,7 @@ def search( # type: ignore # signature of "search" incompatible with supertype
8484
self,
8585
task: str,
8686
latest_n: int = 3,
87-
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
87+
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
8888
crewai_event_bus.emit(
8989
self,
9090
event=MemoryQueryStartedEvent(

src/crewai/memory/long_term/long_term_memory_item.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional, Union
1+
from typing import Any
22

33

44
class LongTermMemoryItem:
@@ -8,8 +8,8 @@ def __init__(
88
task: str,
99
expected_output: str,
1010
datetime: str,
11-
quality: Optional[Union[int, float]] = None,
12-
metadata: Optional[Dict[str, Any]] = None,
11+
quality: int | float | None = None,
12+
metadata: dict[str, Any] | None = None,
1313
):
1414
self.task = task
1515
self.agent = agent

src/crewai/memory/short_term/short_term_memory_item.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any
22

33

44
class ShortTermMemoryItem:
55
def __init__(
66
self,
77
data: Any,
8-
agent: Optional[str] = None,
9-
metadata: Optional[Dict[str, Any]] = None,
8+
agent: str | None = None,
9+
metadata: dict[str, Any] | None = None,
1010
):
1111
self.data = data
1212
self.agent = agent

src/crewai/memory/storage/interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from typing import Any, Dict, List
1+
from typing import Any
22

33

44
class Storage:
55
"""Abstract base class defining the storage interface"""
66

7-
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
7+
def save(self, value: Any, metadata: dict[str, Any]) -> None:
88
pass
99

1010
def search(
1111
self, query: str, limit: int, score_threshold: float
12-
) -> Dict[str, Any] | List[Any]:
12+
) -> dict[str, Any] | list[Any]:
1313
return {}
1414

1515
def reset(self) -> None:

src/crewai/memory/storage/kickoff_task_outputs_storage.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import sqlite3
44
from pathlib import Path
5-
from typing import Any, Dict, List, Optional
5+
from typing import Any
66

77
from crewai.task import Task
88
from crewai.utilities import Printer
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
1818
An updated SQLite storage class for kickoff task outputs storage.
1919
"""
2020

21-
def __init__(self, db_path: Optional[str] = None) -> None:
21+
def __init__(self, db_path: str | None = None) -> None:
2222
if db_path is None:
2323
# Get the parent directory of the default db path and create our db file there
2424
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
@@ -57,15 +57,15 @@ def _initialize_db(self) -> None:
5757
except sqlite3.Error as e:
5858
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
5959
logger.error(error_msg)
60-
raise DatabaseOperationError(error_msg, e)
60+
raise DatabaseOperationError(error_msg, e) from e
6161

6262
def add(
6363
self,
6464
task: Task,
65-
output: Dict[str, Any],
65+
output: dict[str, Any],
6666
task_index: int,
6767
was_replayed: bool = False,
68-
inputs: Dict[str, Any] | None = None,
68+
inputs: dict[str, Any] | None = None,
6969
) -> None:
7070
"""Add a new task output record to the database.
7171
@@ -103,7 +103,7 @@ def add(
103103
except sqlite3.Error as e:
104104
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
105105
logger.error(error_msg)
106-
raise DatabaseOperationError(error_msg, e)
106+
raise DatabaseOperationError(error_msg, e) from e
107107

108108
def update(
109109
self,
@@ -138,7 +138,7 @@ def update(
138138
else value
139139
)
140140

141-
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
141+
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec # noqa: S608
142142
values.append(task_index)
143143

144144
cursor.execute(query, tuple(values))
@@ -151,9 +151,9 @@ def update(
151151
except sqlite3.Error as e:
152152
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
153153
logger.error(error_msg)
154-
raise DatabaseOperationError(error_msg, e)
154+
raise DatabaseOperationError(error_msg, e) from e
155155

156-
def load(self) -> List[Dict[str, Any]]:
156+
def load(self) -> list[dict[str, Any]]:
157157
"""Load all task output records from the database.
158158
159159
Returns:
@@ -192,7 +192,7 @@ def load(self) -> List[Dict[str, Any]]:
192192
except sqlite3.Error as e:
193193
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
194194
logger.error(error_msg)
195-
raise DatabaseOperationError(error_msg, e)
195+
raise DatabaseOperationError(error_msg, e) from e
196196

197197
def delete_all(self) -> None:
198198
"""Delete all task output records from the database.
@@ -212,4 +212,4 @@ def delete_all(self) -> None:
212212
except sqlite3.Error as e:
213213
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
214214
logger.error(error_msg)
215-
raise DatabaseOperationError(error_msg, e)
215+
raise DatabaseOperationError(error_msg, e) from e

src/crewai/memory/storage/ltm_sqlite_storage.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import sqlite3
33
from pathlib import Path
4-
from typing import Any, Dict, List, Optional, Union
4+
from typing import Any
55

66
from crewai.utilities import Printer
77
from crewai.utilities.paths import db_storage_path
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
1212
An updated SQLite storage class for LTM data storage.
1313
"""
1414

15-
def __init__(
16-
self, db_path: Optional[str] = None
17-
) -> None:
15+
def __init__(self, db_path: str | None = None) -> None:
1816
if db_path is None:
1917
# Get the parent directory of the default db path and create our db file there
2018
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
@@ -53,9 +51,9 @@ def _initialize_db(self):
5351
def save(
5452
self,
5553
task_description: str,
56-
metadata: Dict[str, Any],
54+
metadata: dict[str, Any],
5755
datetime: str,
58-
score: Union[int, float],
56+
score: int | float,
5957
) -> None:
6058
"""Saves data to the LTM table with error handling."""
6159
try:
@@ -75,9 +73,7 @@ def save(
7573
color="red",
7674
)
7775

78-
def load(
79-
self, task_description: str, latest_n: int
80-
) -> Optional[List[Dict[str, Any]]]:
76+
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
8177
"""Queries the LTM table by task description with error handling."""
8278
try:
8379
with sqlite3.connect(self.db_path) as conn:
@@ -89,7 +85,7 @@ def load(
8985
WHERE task_description = ?
9086
ORDER BY datetime DESC, score ASC
9187
LIMIT {latest_n}
92-
""", # nosec
88+
""", # nosec # noqa: S608
9389
(task_description,),
9490
)
9591
rows = cursor.fetchall()
@@ -125,4 +121,4 @@ def reset(
125121
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
126122
color="red",
127123
)
128-
return None
124+
return

src/crewai/memory/storage/rag_storage.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22
import traceback
33
import warnings
4-
from typing import Any
4+
from typing import Any, cast
55

66
from crewai.rag.chromadb.config import ChromaDBConfig
7+
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
78
from crewai.rag.config.utils import get_rag_client
89
from crewai.rag.core.base_client import BaseClient
910
from crewai.rag.embeddings.factory import get_embedding_function
@@ -21,8 +22,13 @@ class RAGStorage(BaseRAGStorage):
2122
"""
2223

2324
def __init__(
24-
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
25-
):
25+
self,
26+
type: str,
27+
allow_reset: bool = True,
28+
embedder_config: dict[str, Any] | None = None,
29+
crew: Any = None,
30+
path: str | None = None,
31+
) -> None:
2632
super().__init__(type, allow_reset, embedder_config, crew)
2733
agents = crew.agents if crew else []
2834
agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -44,7 +50,11 @@ def __init__(
4450

4551
if self.embedder_config:
4652
embedding_function = get_embedding_function(self.embedder_config)
47-
config = ChromaDBConfig(embedding_function=embedding_function)
53+
config = ChromaDBConfig(
54+
embedding_function=cast(
55+
ChromaEmbeddingFunctionWrapper, embedding_function
56+
)
57+
)
4858
self._client = create_client(config)
4959

5060
def _get_client(self) -> BaseClient:

0 commit comments

Comments
 (0)