Skip to content

Commit e77a86a

Browse files
fix(auth + brains): handled multiple brains and single brain deployments
1 parent b7293d1 commit e77a86a

File tree

11 files changed

+176
-112
lines changed

11 files changed

+176
-112
lines changed

.env.example

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ BRAINPAT_TOKEN="your_token"
4646
# MultiBrain
4747
# Choose to allow or block automatic creation of new brains on requests with non existing new brain_ids
4848
BRAIN_CREATION_ALLOWED="true"
49-
# Choose whether to use for every brain the main pat or to use dedicated one for each brain
50-
USE_ONLY_SYSTEM_PAT="false"
49+
# Choose whether to fallback to default brain if not provided
50+
DEFAULT_BRAIN_FALLBACK="true"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "brainapi2"
3-
version = "1.6.6-dev"
3+
version = "1.6.14-dev"
44
description = "Version 1.x.x of the BrainAPI memory layer."
55
authors = [
66
{name = "Christian",email = "[email protected]"}

src/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ class MilvusConfig:
7878

7979
def __init__(self):
8080
self.host = os.getenv("MILVUS_HOST")
81-
self.port = os.getenv("MILVUS_PORT")
81+
port_str = os.getenv("MILVUS_PORT")
82+
self.port = int(port_str) if port_str else None
8283
self.uri = os.getenv("MILVUS_URI")
8384
self.token = os.getenv("MILVUS_TOKEN")
8485
if [self.host, self.port].count(None) > 0 and [self.uri, self.token].count(

src/lib/milvus/client.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class MilvusClient(VectorStoreClient):
3838
def __init__(self):
3939
self._client = None
4040
self._lock = None
41+
self._pid = None
4142

4243
def _get_lock(self):
4344
if self._lock is None:
@@ -46,23 +47,38 @@ def _get_lock(self):
4647
self._lock = threading.Lock()
4748
return self._lock
4849

50+
def _reset_client_if_forked(self):
51+
import os as os_module
52+
53+
current_pid = os_module.getpid()
54+
if self._pid is not None and self._pid != current_pid:
55+
if self._client is not None:
56+
try:
57+
self._client.close()
58+
except Exception:
59+
pass
60+
self._client = None
61+
self._pid = current_pid
62+
4963
@property
5064
def client(self):
5165
"""
5266
Lazy initialization of the Milvus client.
5367
"""
68+
self._reset_client_if_forked()
5469
if self._client is None:
5570
with self._get_lock():
71+
self._reset_client_if_forked()
5672
if self._client is None:
5773
if config.milvus.uri and config.milvus.token:
5874
self._client = Milvus(
5975
uri=config.milvus.uri,
6076
token=config.milvus.token,
6177
)
6278
elif config.milvus.host and config.milvus.port:
79+
uri = f"http://{config.milvus.host}:{config.milvus.port}"
6380
self._client = Milvus(
64-
host=config.milvus.host,
65-
port=config.milvus.port,
81+
uri=uri,
6682
token=config.milvus.token if config.milvus.token else None,
6783
)
6884
else:

src/lib/neo4j/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def add_relationship(
230230
CREATE (a)-[:{":".join(self._clean_labels([predicate.name]))} {{ Description: '{safe_desc}' }}]->(b)
231231
RETURN a, b
232232
"""
233+
self.ensure_database(brain_id)
233234
result = self.driver.execute_query(cypher_query, database_=brain_id)
234235
return result
235236

@@ -248,6 +249,7 @@ def search_graph(self, nodes: list[Node], brain_id: str) -> list[Node]:
248249
queries.append(query)
249250

250251
cypher_query = " UNION ".join(queries)
252+
self.ensure_database(brain_id)
251253
result = self.driver.execute_query(cypher_query, database_=brain_id)
252254
return result
253255

@@ -260,6 +262,7 @@ def node_text_search(self, text: str, brain_id: str) -> list[Node]:
260262
WHERE toLower(n.Name) CONTAINS toLower('{text}')
261263
RETURN n
262264
"""
265+
self.ensure_database(brain_id)
263266
result = self.driver.execute_query(cypher_query, database_=brain_id)
264267
return [
265268
Node(
@@ -315,6 +318,7 @@ def get_nodes_by_uuid(
315318
if with_relationships:
316319
cypher_query += ", r, m.uuid as m_uuid, m.name as m_name, labels(m) as m_labels, m.description as m_description, properties(m) as m_properties"
317320

321+
self.ensure_database(brain_id)
318322
result = self.driver.execute_query(cypher_query, database_=brain_id)
319323

320324
if with_relationships:
@@ -362,6 +366,7 @@ def get_graph_entities(self, brain_id: str) -> list[str]:
362366
MATCH (n)
363367
RETURN DISTINCT labels(n) as labels
364368
"""
369+
self.ensure_database(brain_id)
365370
result = self.driver.execute_query(cypher_query, database_=brain_id)
366371
return [label for record in result.records for label in record["labels"]]
367372

@@ -373,6 +378,7 @@ def get_graph_relationships(self, brain_id: str) -> list[str]:
373378
CALL db.relationshipTypes() YIELD relationshipType
374379
RETURN relationshipType
375380
"""
381+
self.ensure_database(brain_id)
376382
result = self.driver.execute_query(cypher_query, database_=brain_id)
377383
return [record["relationshipType"] for record in result.records]
378384

@@ -384,6 +390,7 @@ def get_by_uuid(self, uuid: str, brain_id: str) -> Node:
384390
MATCH (n) WHERE n.uuid = '{uuid}'
385391
RETURN n.uuid as uuid, n.name as name, labels(n) as labels, n.description as description, properties(n) as properties
386392
"""
393+
self.ensure_database(brain_id)
387394
result = self.driver.execute_query(cypher_query, database_=brain_id)
388395
if not result.records or len(result.records) == 0:
389396
return None
@@ -403,6 +410,7 @@ def get_by_uuids(self, uuids: list[str], brain_id: str) -> list[Node]:
403410
MATCH (n) WHERE n.uuid IN ["{'","'.join(uuids)}"]
404411
RETURN n.uuid as uuid, n.name as name, labels(n) as labels, n.description as description, properties(n) as properties
405412
"""
413+
self.ensure_database(brain_id)
406414
result = self.driver.execute_query(cypher_query, database_=brain_id)
407415
return [
408416
Node(
@@ -438,6 +446,7 @@ def get_by_identification_params(
438446
MATCH (n{(":" + ":".join(self._clean_labels(entity_types))) if entity_types else ""}) {("WHERE " + " AND ".join(where_clauses)) if where_clauses else ""}
439447
RETURN n.uuid as uuid, n.name as name, labels(n) as labels, n.description as description, properties(n) as properties
440448
"""
449+
self.ensure_database(brain_id)
441450
result = self.driver.execute_query(cypher_query, database_=brain_id)
442451
if not result.records or len(result.records) == 0:
443452
return None
@@ -457,6 +466,7 @@ def get_graph_property_keys(self, brain_id: str) -> list[str]:
457466
CALL db.propertyKeys() YIELD propertyKey
458467
RETURN propertyKey
459468
"""
469+
self.ensure_database(brain_id)
460470
result = self.driver.execute_query(cypher_query, database_=brain_id)
461471
return [record["propertyKey"] for record in result.records]
462472

@@ -474,6 +484,7 @@ def get_neighbors(
474484
CASE WHEN startNode(r2) = c THEN 'out' ELSE 'in' END AS direction,
475485
c.uuid AS c_uuid, c.name AS c_name, labels(c) AS c_labels, c.description AS c_description, properties(c) AS c_properties
476486
"""
487+
self.ensure_database(brain_id)
477488
result = self.driver.execute_query(cypher_query, database_=brain_id)
478489

479490
neighbors = []
@@ -538,6 +549,7 @@ def get_neighbor_node_tuples(
538549
m.uuid as m_uuid, m.name as m_name, labels(m) as m_labels,
539550
m.description as m_description, properties(m) as m_properties, r as rel
540551
"""
552+
self.ensure_database(brain_id)
541553
result = self.driver.execute_query(cypher_query, database_=brain_id)
542554

543555
if not result.records:
@@ -632,6 +644,7 @@ def get_connected_nodes(
632644
n.uuid as n_uuid, n.name as n_name, labels(n) as n_labels, n.description as n_description, properties(n) as n_properties,
633645
CASE WHEN startNode(r) = n THEN 'out' ELSE 'in' END AS direction
634646
"""
647+
self.ensure_database(brain_id)
635648
result = self.driver.execute_query(cypher_query, database_=brain_id)
636649
return [
637650
(
@@ -727,6 +740,7 @@ def search_relationships(
727740
{"WHERE " + " AND ".join(filters) if filters else ""}
728741
RETURN count(r) AS total
729742
"""
743+
self.ensure_database(brain_id)
730744
result = self.driver.execute_query(cypher_query, database_=brain_id)
731745
count_result = self.driver.execute_query(cypher_count, database_=brain_id)
732746
total = 0
@@ -854,6 +868,7 @@ def search_entities(
854868
{"WHERE " + " AND ".join(filters) if filters else ""}
855869
RETURN count(n) AS total
856870
"""
871+
self.ensure_database(brain_id)
857872
result = self.driver.execute_query(cypher_query, database_=brain_id)
858873
count_result = self.driver.execute_query(cypher_count, database_=brain_id)
859874
total = 0

src/services/api/constants/requests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ class RetrieveNeighborsAiModeRequestBody(BaseModel):
155155
...,
156156
description="The description of the neighbors to look for.",
157157
)
158+
brain_id: str = Field(
159+
default="default", description="The brain identifier to store the data in."
160+
)
158161

159162

160163
class RetrieveNeighborsWithIdentificationParamsRequestBody(BaseModel):
@@ -167,3 +170,14 @@ class RetrieveNeighborsWithIdentificationParamsRequestBody(BaseModel):
167170
description="The identification parameters of the entity to get neighbors for.",
168171
)
169172
limit: int = Field(10, description="The number of neighbors to return.")
173+
brain_id: str = Field(
174+
default="default", description="The brain identifier to store the data in."
175+
)
176+
177+
178+
class CreateBrainRequest(BaseModel):
179+
"""
180+
Request body for the create brain endpoint.
181+
"""
182+
183+
brain_id: str

src/services/api/controllers/system.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import asyncio
1212

1313
from pydantic import BaseModel
14+
from src.services.api.constants.requests import CreateBrainRequest
1415
from src.services.data.main import data_adapter
1516

1617

@@ -22,14 +23,6 @@ async def get_brains_list():
2223
return result
2324

2425

25-
class CreateBrainRequest(BaseModel):
26-
"""
27-
Request body for the create brain endpoint.
28-
"""
29-
30-
brain_id: str
31-
32-
3326
async def create_new_brain(request: CreateBrainRequest):
3427
"""
3528
Create a new brain

src/services/api/middlewares/auth.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,51 +23,50 @@ async def dispatch(self, request: Request, call_next):
2323
if request.method == "OPTIONS":
2424
return await call_next(request)
2525

26+
# Variables ----------------------------------------------
2627
brainpat = request.headers.get("BrainPAT") or getattr(
2728
request.state, "pat", None
2829
)
29-
brain_id = getattr(request.state, "brain_id", None)
30-
31-
cachepat_key = f"brainpat:{brain_id or 'default'}"
30+
system_pat = os.getenv("BRAINPAT_TOKEN")
3231

33-
use_only_system_pat = os.getenv("USE_ONLY_SYSTEM_PAT") == "true"
34-
35-
if use_only_system_pat:
36-
cachepat_key = "brainpat:system"
32+
if request.url.path.startswith("/system") or request.url.path == "/":
33+
if brainpat == system_pat:
34+
return await call_next(request)
35+
return JSONResponse(
36+
status_code=status.HTTP_401_UNAUTHORIZED,
37+
content={"detail": "Invalid or missing BrainPAT header"},
38+
)
3739

40+
brain_id = getattr(request.state, "brain_id", None)
41+
if not brain_id:
42+
return JSONResponse(
43+
status_code=status.HTTP_400_BAD_REQUEST,
44+
content={"detail": "Brain ID is required."},
45+
)
46+
cachepat_key = f"brainpat:{brain_id}"
3847
cached_brainpat = cache_adapter.get(key=cachepat_key, brain_id="system")
3948

40-
if not cached_brainpat and not use_only_system_pat:
49+
# Logic --------------------------------------------------
50+
if brainpat == system_pat:
51+
return await call_next(request)
52+
53+
if not cached_brainpat:
4154
stored_brain = data_adapter.get_brain(name_key=brain_id)
42-
system_pat = os.getenv("BRAINPAT_TOKEN")
43-
if not stored_brain:
44-
if brainpat != system_pat:
45-
return JSONResponse(
46-
status_code=status.HTTP_401_UNAUTHORIZED,
47-
content={"detail": "Invalid or missing BrainPAT header"},
48-
)
49-
cached_brainpat = system_pat
50-
else:
51-
cached_brainpat = stored_brain.pat
52-
cache_adapter.set(
53-
key=cachepat_key,
54-
value=stored_brain.pat,
55-
brain_id="system",
55+
if not stored_brain or stored_brain.pat != brainpat:
56+
return JSONResponse(
57+
status_code=status.HTTP_401_UNAUTHORIZED,
58+
content={"detail": "Invalid or missing BrainPAT header"},
5659
)
57-
if not cached_brainpat and use_only_system_pat:
58-
system_pat = os.getenv("BRAINPAT_TOKEN")
59-
cached_brainpat = system_pat
60+
cached_brainpat = stored_brain.pat
6061
cache_adapter.set(
61-
key="brainpat:system",
62-
value=system_pat,
63-
brain_id="system",
62+
key=cachepat_key, value=stored_brain.pat, brain_id="system"
6463
)
65-
66-
if cached_brainpat != brainpat:
64+
elif cached_brainpat != brainpat:
6765
return JSONResponse(
6866
status_code=status.HTTP_401_UNAUTHORIZED,
6967
content={"detail": "Invalid or missing BrainPAT header"},
7068
)
69+
7170
response = await call_next(request)
7271

7372
return response

0 commit comments

Comments
 (0)