Skip to content

Commit f9e712e

Browse files
feat(brains)
1 parent 69abbe4 commit f9e712e

File tree

23 files changed

+414
-169
lines changed

23 files changed

+414
-169
lines changed

src/adapters/cache.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
-----
99
"""
1010

11+
from typing import Optional
1112
from .interfaces.cache import CacheClient
1213

1314

@@ -25,20 +26,26 @@ def add_client(self, client: CacheClient) -> None:
2526
"""
2627
self.cache = client
2728

28-
def get(self, key: str) -> str:
29+
def get(self, key: str, brain_id: str = "default") -> str:
2930
"""
3031
Get a value from the cache.
3132
"""
32-
return self.cache.get(key)
33+
return self.cache.get(key, brain_id)
3334

34-
def set(self, key: str, value: str, expires_in: int) -> bool:
35+
def set(
36+
self,
37+
key: str,
38+
value: str,
39+
brain_id: str = "default",
40+
expires_in: Optional[int] = None,
41+
) -> bool:
3542
"""
3643
Set a value in the cache with an expiration time.
3744
"""
38-
return self.cache.set(key, value, expires_in)
45+
return self.cache.set(key, value, brain_id, expires_in)
3946

40-
def delete(self, key: str) -> bool:
47+
def delete(self, key: str, brain_id: str = "default") -> bool:
4148
"""
4249
Delete a value from the cache.
4350
"""
44-
return self.cache.delete(key)
51+
return self.cache.delete(key, brain_id)

src/adapters/data.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from typing import List, Tuple
1212
from src.adapters.interfaces.data import DataClient, SearchResult
13-
from src.constants.data import Observation, StructuredData, TextChunk
13+
from src.constants.data import Brain, Observation, StructuredData, TextChunk
1414

1515

1616
class DataAdapter:
@@ -27,34 +27,52 @@ def add_client(self, client: DataClient) -> None:
2727
"""
2828
self.data = client
2929

30-
def save_text_chunk(self, text_chunk: TextChunk) -> TextChunk:
30+
def save_text_chunk(
31+
self, text_chunk: TextChunk, brain_id: str = "default"
32+
) -> TextChunk:
3133
"""
3234
Save a text chunk to the data client.
3335
"""
34-
return self.data.save_text_chunk(text_chunk)
36+
return self.data.save_text_chunk(text_chunk, brain_id)
3537

36-
def save_observations(self, observations: List[Observation]) -> Observation:
38+
def save_observations(
39+
self, observations: List[Observation], brain_id: str = "default"
40+
) -> Observation:
3741
"""
3842
Save a list of observations to the data client.
3943
"""
40-
return self.data.save_observations(observations)
44+
return self.data.save_observations(observations, brain_id)
4145

42-
def search(self, text: str) -> SearchResult:
46+
def search(self, text: str, brain_id: str = "default") -> SearchResult:
4347
"""
4448
Search data by text and return a list of text chunks and observations.
4549
"""
46-
return self.data.search(text)
50+
return self.data.search(text, brain_id)
4751

4852
def get_text_chunks_by_ids(
49-
self, ids: List[str], with_observations: bool
53+
self, ids: List[str], with_observations: bool, brain_id: str = "default"
5054
) -> Tuple[List[TextChunk], List[Observation]]:
5155
"""
5256
Get data by their IDs.
5357
"""
54-
return self.data.get_text_chunks_by_ids(ids, with_observations)
58+
return self.data.get_text_chunks_by_ids(ids, with_observations, brain_id)
5559

56-
def save_structured_data(self, structured_data: StructuredData) -> StructuredData:
60+
def save_structured_data(
61+
self, structured_data: StructuredData, brain_id: str = "default"
62+
) -> StructuredData:
5763
"""
5864
Save a structured data to the data client.
5965
"""
60-
return self.data.save_structured_data(structured_data)
66+
return self.data.save_structured_data(structured_data, brain_id)
67+
68+
def create_brain(self, name_key: str) -> Brain:
69+
"""
70+
Create a new brain in the data client.
71+
"""
72+
return self.data.create_brain(name_key)
73+
74+
def get_brain(self, name_key: str) -> Brain:
75+
"""
76+
Get a brain from the data client.
77+
"""
78+
return self.data.get_brain(name_key)

src/adapters/embeddings.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,32 +41,43 @@ def add_client(self, client: VectorStoreClient) -> None:
4141
"""
4242
self.vector_store = client
4343

44-
def add_vectors(self, vectors: list[Vector], store: str) -> None:
44+
def add_vectors(
45+
self, vectors: list[Vector], store: str, brain_id: str = "default"
46+
) -> None:
4547
"""
4648
Add vectors to the vector store.
4749
"""
48-
return self.vector_store.add_vectors(vectors, store)
50+
return self.vector_store.add_vectors(vectors, store, brain_id)
4951

50-
def search_vectors(self, query: str, store: str, k: int = 10) -> list[Vector]:
52+
def search_vectors(
53+
self, query: str, store: str, brain_id: str = "default", k: int = 10
54+
) -> list[Vector]:
5155
"""
5256
Search vectors in the vector store and return the top k vectors.
5357
"""
54-
return self.vector_store.search_vectors(query, store, k)
58+
return self.vector_store.search_vectors(query, store, brain_id, k)
5559

56-
def get_by_ids(self, ids: list[str], store: str) -> list[Vector]:
60+
def get_by_ids(
61+
self, ids: list[str], store: str, brain_id: str = "default"
62+
) -> list[Vector]:
5763
"""
5864
Get vectors by their IDs.
5965
"""
60-
return self.vector_store.get_by_ids(ids, store)
66+
return self.vector_store.get_by_ids(ids, store, brain_id)
6167

6268
def search_similar_by_ids(
63-
self, vector_ids: list[str], store: str, min_similarity: float, limit: int = 10
69+
self,
70+
vector_ids: list[str],
71+
store: str,
72+
min_similarity: float,
73+
limit: int = 10,
74+
brain_id: str = "default",
6475
) -> list[Vector]:
6576
"""
6677
Search similar vectors by their IDs.
6778
"""
6879
return self.vector_store.search_similar_by_ids(
69-
vector_ids, store, min_similarity, limit
80+
vector_ids, store, min_similarity, limit, brain_id
7081
)
7182

7283

src/adapters/graph.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
Predicate,
1717
SearchEntitiesResult,
1818
SearchRelationshipsResult,
19-
Triple,
2019
)
2120

2221

@@ -50,57 +49,60 @@ def add_client(self, client: GraphClient) -> None:
5049
"""
5150
self.graph = client
5251

53-
def execute_operation(self, operation: str) -> str:
52+
def execute_operation(self, operation: str, brain_id: str = "default") -> str:
5453
"""
5554
Execute a generic graph operation.
5655
"""
5756
try:
58-
return self.graph.execute_operation(operation)
57+
return self.graph.execute_operation(operation, brain_id)
5958
except Exception as e: # pylint: disable=broad-exception-caught
6059
print(f"Error executing graph operation: {e} - {operation}")
6160
return f"Error executing graph operation: {e}"
6261

6362
def add_nodes(
6463
self,
6564
nodes: list[Node],
65+
brain_id: str = "default",
6666
identification_params: Optional[dict] = None,
6767
metadata: Optional[dict] = None,
68-
database: Optional[str] = None,
6968
) -> list[Node] | str:
7069
"""
7170
Add nodes to the graph.
7271
"""
73-
return self.graph.add_nodes(nodes, identification_params, metadata, database)
72+
return self.graph.add_nodes(nodes, brain_id, identification_params, metadata)
7473

7574
def add_relationship(
7675
self,
7776
subject: Node,
7877
predicate: Predicate,
7978
to_object: Node,
79+
brain_id: str = "default",
8080
) -> str:
8181
"""
8282
Add a relationship between two nodes to the graph.
8383
"""
84-
return self.graph.add_relationship(subject, predicate, to_object)
84+
return self.graph.add_relationship(subject, predicate, to_object, brain_id)
8585

8686
def search_graph(
8787
self,
8888
nodes: list[Node],
89+
brain_id: str = "default",
8990
) -> list[Node]:
9091
"""
9192
Search the graph for nodes and 1 degree relationships.
9293
"""
93-
return self.graph.search_graph(nodes)
94+
return self.graph.search_graph(nodes, brain_id)
9495

95-
def node_text_search(self, text: str) -> list[Node]:
96+
def node_text_search(self, text: str, brain_id: str = "default") -> list[Node]:
9697
"""
9798
Search the graph for nodes by partial text match into the name of the nodes.
9899
"""
99-
return self.graph.node_text_search(text)
100+
return self.graph.node_text_search(text, brain_id)
100101

101102
def get_nodes_by_uuid(
102103
self,
103104
uuids: list[str],
105+
brain_id: str = "default",
104106
with_relationships: Optional[bool] = False,
105107
relationships_depth: Optional[int] = 1,
106108
relationships_type: Optional[list[str]] = None,
@@ -111,80 +113,83 @@ def get_nodes_by_uuid(
111113
"""
112114
return self.graph.get_nodes_by_uuid(
113115
uuids,
116+
brain_id,
114117
with_relationships,
115118
relationships_depth,
116119
relationships_type,
117120
preferred_labels,
118121
)
119122

120-
def get_graph_entities(self) -> list[str]:
123+
def get_graph_entities(self, brain_id: str = "default") -> list[str]:
121124
"""
122125
Get the entities of the graph.
123126
"""
124-
return self.graph.get_graph_entities()
127+
return self.graph.get_graph_entities(brain_id)
125128

126-
def get_graph_relationships(self) -> list[str]:
129+
def get_graph_relationships(self, brain_id: str = "default") -> list[str]:
127130
"""
128131
Get the relationships of the graph.
129132
"""
130-
return self.graph.get_graph_relationships()
133+
return self.graph.get_graph_relationships(brain_id)
131134

132-
def get_graph_property_keys(self) -> list[str]:
135+
def get_graph_property_keys(self, brain_id: str = "default") -> list[str]:
133136
"""
134137
Get the property keys of the graph.
135138
"""
136-
return self.graph.get_graph_property_keys()
139+
return self.graph.get_graph_property_keys(brain_id)
137140

138-
def get_by_uuid(self, uuid: str) -> Node:
141+
def get_by_uuid(self, uuid: str, brain_id: str = "default") -> Node:
139142
"""
140143
Get a node by its UUID.
141144
"""
142-
return self.graph.get_by_uuid(uuid)
145+
return self.graph.get_by_uuid(uuid, brain_id)
143146

144-
def get_by_uuids(self, uuids: list[str]) -> list[Node]:
147+
def get_by_uuids(self, uuids: list[str], brain_id: str = "default") -> list[Node]:
145148
"""
146149
Get nodes by their UUIDs.
147150
"""
148-
return self.graph.get_by_uuids(uuids)
151+
return self.graph.get_by_uuids(uuids, brain_id)
149152

150153
def get_by_identification_params(
151154
self,
152155
identification_params: IdentificationParams,
156+
brain_id: str = "default",
153157
entity_types: Optional[list[str]] = None,
154158
) -> Node:
155159
"""
156160
Get a node by its identification params and entity types.
157161
"""
158162
return self.graph.get_by_identification_params(
159-
identification_params, entity_types
163+
identification_params, brain_id, entity_types
160164
)
161165

162166
def get_neighbors(
163-
self, node: Node, limit: int
167+
self, node: Node, limit: int, brain_id: str = "default"
164168
) -> list[Tuple[Node, Predicate, Node]]:
165169
"""
166170
Get the neighbors of a node.
167171
"""
168-
return self.graph.get_neighbors(node, limit)
172+
return self.graph.get_neighbors(node, limit, brain_id)
169173

170174
def get_node_with_rel_by_uuid(
171-
self, rel_ids_with_node_ids: list[tuple[str, str]]
175+
self, rel_ids_with_node_ids: list[tuple[str, str]], brain_id: str = "default"
172176
) -> list[dict]:
173177
"""
174178
Get the node with the relationships by their UUIDs.
175179
"""
176-
return self.graph.get_node_with_rel_by_uuid(rel_ids_with_node_ids)
180+
return self.graph.get_node_with_rel_by_uuid(rel_ids_with_node_ids, brain_id)
177181

178182
def get_neighbor_node_tuples(
179-
self, a_uuid: str, b_uuids: list[str]
183+
self, a_uuid: str, b_uuids: list[str], brain_id: str = "default"
180184
) -> list[Tuple[Node, Predicate, Node]]:
181185
"""
182186
Get the neighbor node tuples by their UUIDs.
183187
"""
184-
return self.graph.get_neighbor_node_tuples(a_uuid, b_uuids)
188+
return self.graph.get_neighbor_node_tuples(a_uuid, b_uuids, brain_id)
185189

186190
def get_connected_nodes(
187191
self,
192+
brain_id: str = "default",
188193
node: Optional[Node] = None,
189194
uuids: Optional[list[str]] = None,
190195
limit: Optional[int] = 10,
@@ -194,11 +199,12 @@ def get_connected_nodes(
194199
Get the connected nodes by their UUIDs.
195200
"""
196201
return self.graph.get_connected_nodes(
197-
node=node, uuids=uuids, limit=limit, with_labels=with_labels
202+
brain_id, node=node, uuids=uuids, limit=limit, with_labels=with_labels
198203
)
199204

200205
def search_relationships(
201206
self,
207+
brain_id: str = "default",
202208
limit: int = 10,
203209
skip: int = 0,
204210
relationship_types: Optional[list[str]] = None,
@@ -213,6 +219,7 @@ def search_relationships(
213219
relationship_uuids = []
214220
# TODO: semantic search + src/core/agents/tools/kg_agent/KGAgentAddTripletsTool.py:165
215221
return self.graph.search_relationships(
222+
brain_id,
216223
limit,
217224
skip,
218225
relationship_types,
@@ -225,6 +232,7 @@ def search_relationships(
225232

226233
def search_entities(
227234
self,
235+
brain_id: str = "default",
228236
limit: int = 10,
229237
skip: int = 0,
230238
node_labels: Optional[list[str]] = None,
@@ -236,7 +244,7 @@ def search_entities(
236244
node_uuids = []
237245
# TODO: semantic search
238246
return self.graph.search_entities(
239-
limit, skip, node_labels, node_uuids, query_text
247+
brain_id, limit, skip, node_labels, node_uuids, query_text
240248
)
241249

242250

0 commit comments

Comments
 (0)