Skip to content

Commit 50e3928

Browse files
feat(pat): multiple pat x brain
1 parent d437de2 commit 50e3928

File tree

9 files changed

+121
-11
lines changed

9 files changed

+121
-11
lines changed

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,9 @@ MONGO_PASSWORD="password"
4242
# Auth
4343
# Used to handle authentication
4444
BRAINPAT_TOKEN="your_token"
45+
46+
# MultiBrain
47+
# Choose to allow or block automatic creation of new brains on requests with non existing new brain_ids
48+
BRAIN_CREATION_ALLOWED="true"
49+
# Choose whether to use for every brain the main pat or to use dedicated one for each brain
50+
USE_ONLY_SYSTEM_PAT="false"

src/constants/data.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,17 @@ class Brain(BaseModel):
9090

9191
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
9292
name_key: str = Field(description="The key used to identify the brain.")
93+
94+
@staticmethod
95+
def _random_pat() -> str:
96+
import random
97+
98+
chars = []
99+
for _ in range(48):
100+
chars.append(random.choice("abcdefghijklmnopqrstuvwxyz0123456789"))
101+
return "".join(chars)
102+
103+
pat: str = Field(
104+
description="The personal access token for the brain.",
105+
default_factory=_random_pat,
106+
)

src/lib/mongo/client.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,32 @@ def save_structured_data(
134134

135135
def create_brain(self, name_key: str) -> Brain:
136136
collection = self.get_collection("brains", "system")
137-
result = collection.insert_one({"name_key": name_key})
138-
return Brain(id=str(result.inserted_id), name_key=name_key)
137+
brain = Brain(name_key=name_key)
138+
brain_dict = brain.model_dump(mode="json", exclude={"id"})
139+
result = collection.insert_one(brain_dict)
140+
brain.id = str(result.inserted_id)
141+
return brain
139142

140143
def get_brain(self, name_key: str) -> Brain:
141144
collection = self.get_collection("brains", "system")
142145
result = collection.find_one({"name_key": name_key})
143146
if not result:
144147
return None
145-
return Brain(id=str(result["_id"]), name_key=result["name_key"])
148+
return Brain(
149+
id=str(result["_id"]),
150+
name_key=result["name_key"],
151+
pat=result.get("pat"),
152+
)
146153

147154
def get_brains_list(self) -> List[Brain]:
148155
collection = self.get_collection("brains", "system")
149156
result = collection.find()
150157
return [
151-
Brain(id=str(result["_id"]), name_key=result["name_key"])
158+
Brain(
159+
id=str(result["_id"]),
160+
name_key=result["name_key"],
161+
pat=result.get("pat"),
162+
)
152163
for result in result
153164
]
154165

src/lib/redis/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ def get(self, key: str, brain_id: str) -> str:
4343
Get a value from the cache.
4444
"""
4545
prefixed_key = self._get_key(key, brain_id)
46-
return self.client.get(prefixed_key)
46+
result = self.client.get(prefixed_key)
47+
if result is None:
48+
return None
49+
if isinstance(result, bytes):
50+
return result.decode("utf-8")
51+
return result
4752

4853
def set(
4954
self, key: str, value: str, brain_id: str, expires_in: Optional[int] = None

src/services/api/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@
2525
app = FastAPI()
2626

2727
app.add_middleware(BrainPATMiddleware)
28+
app.add_middleware(BrainMiddleware)
2829
app.add_middleware(
2930
CORSMiddleware,
3031
allow_origins=["*"],
3132
allow_credentials=False,
3233
allow_methods=["*"],
3334
allow_headers=["*"],
3435
)
35-
app.add_middleware(BrainMiddleware)
36+
3637

3738
app.include_router(ingest_router)
3839
app.include_router(retrieve_router)

src/services/api/controllers/system.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,11 @@ async def get_brains_list():
1818
"""
1919
result = await asyncio.to_thread(data_adapter.get_brains_list)
2020
return result
21+
22+
23+
async def create_new_brain(brain_id: str):
24+
"""
25+
Create a new brain
26+
"""
27+
result = await asyncio.to_thread(data_adapter.create_brain, brain_id)
28+
return result

src/services/api/middlewares/auth.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,64 @@
88
-----
99
"""
1010

11+
import os
1112
from fastapi import Request
1213
from fastapi.responses import JSONResponse
1314
from starlette.middleware.base import BaseHTTPMiddleware
1415
from starlette import status
1516

16-
from src.config import config
17+
from src.services.kg_agent.main import cache_adapter
18+
from src.services.data.main import data_adapter
1719

1820

1921
class BrainPATMiddleware(BaseHTTPMiddleware):
2022
async def dispatch(self, request: Request, call_next):
2123
if request.method == "OPTIONS":
2224
return await call_next(request)
23-
brainpat = request.headers.get("BrainPAT")
24-
if brainpat != config.brainpat_token:
25+
26+
brainpat = request.headers.get("BrainPAT") or getattr(
27+
request.state, "pat", None
28+
)
29+
brain_id = getattr(request.state, "brain_id", None)
30+
31+
cachepat_key = f"brainpat:{brain_id or 'default'}"
32+
33+
use_only_system_pat = os.getenv("USE_ONLY_SYSTEM_PAT") == "true"
34+
35+
if use_only_system_pat:
36+
cachepat_key = "brainpat:system"
37+
38+
cached_brainpat = cache_adapter.get(key=cachepat_key, brain_id="system")
39+
40+
if not cached_brainpat and not use_only_system_pat:
41+
stored_brain = data_adapter.get_brain(name_key=brain_id)
42+
system_pat = os.getenv("BRAINPAT_TOKEN")
43+
if not stored_brain and brainpat != system_pat:
44+
return JSONResponse(
45+
status_code=status.HTTP_401_UNAUTHORIZED,
46+
content={"detail": "Invalid or missing BrainPAT header"},
47+
)
48+
49+
cached_brainpat = stored_brain.pat
50+
cache_adapter.set(
51+
key=cachepat_key,
52+
value=stored_brain.pat,
53+
brain_id="system",
54+
)
55+
if not cached_brainpat and use_only_system_pat:
56+
system_pat = os.getenv("BRAINPAT_TOKEN")
57+
cached_brainpat = system_pat
58+
cache_adapter.set(
59+
key="brainpat:system",
60+
value=system_pat,
61+
brain_id="system",
62+
)
63+
64+
if cached_brainpat != brainpat:
2565
return JSONResponse(
2666
status_code=status.HTTP_401_UNAUTHORIZED,
2767
content={"detail": "Invalid or missing BrainPAT header"},
2868
)
2969
response = await call_next(request)
70+
3071
return response

src/services/api/middlewares/brains.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
-----
99
"""
1010

11+
import os
1112
from fastapi import Request
1213
from fastapi.responses import JSONResponse
1314
from starlette.middleware.base import BaseHTTPMiddleware
@@ -62,9 +63,11 @@ async def receive():
6263
key=f"brain:{brain_id}", brain_id="system"
6364
)
6465

65-
if not cached_brain_id:
66-
stored_brain = data_adapter.get_brain(name_key=brain_id)
66+
brain_creation_allowed = os.getenv("BRAIN_CREATION_ALLOWED") == "true"
6767

68+
if not cached_brain_id and brain_creation_allowed:
69+
stored_brain = data_adapter.get_brain(name_key=brain_id)
70+
new_brain = None
6871
if not stored_brain:
6972
new_brain = data_adapter.create_brain(name_key=brain_id)
7073
cache_adapter.set(
@@ -79,6 +82,18 @@ async def receive():
7982
brain_id="system",
8083
)
8184

85+
cached_brain_id = new_brain.id if new_brain else stored_brain.id
86+
87+
use_only_system_pat = os.getenv("USE_ONLY_SYSTEM_PAT") == "true"
88+
if not use_only_system_pat:
89+
request.state.pat = new_brain.pat if new_brain else stored_brain.pat
90+
91+
if not cached_brain_id:
92+
return JSONResponse(
93+
status_code=status.HTTP_406_NOT_ACCEPTABLE,
94+
content={"detail": "Brain not found or creation is not allowed."},
95+
)
96+
8297
request.state.brain_id = brain_id
8398

8499
return await call_next(request)

src/services/api/routes/system.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fastapi import APIRouter
1212
from src.services.api.controllers.system import (
1313
get_brains_list as get_brains_list_controller,
14+
create_new_brain as create_new_brain_controller,
1415
)
1516

1617
system_router = APIRouter(prefix="/system", tags=["system"])
@@ -22,3 +23,11 @@ async def get_brains_list():
2223
Get the list of brains.
2324
"""
2425
return await get_brains_list_controller()
26+
27+
28+
@system_router.post(path="/brains")
29+
async def create_brain():
30+
"""
31+
Create a new brain
32+
"""
33+
return await create_new_brain_controller()

0 commit comments

Comments
 (0)