Restrict model access by the resource group of the authenticated token.

What the patch does:
  * Pool keeps the ModelConfig it is constructed with.
  * verify_token returns the matched token record so routes can inspect it.
  * Every OpenAI route receives the token through Security(verify_token).
  * check_access raises HTTP 403 when a model lists "resource_groups" and
    the token's "resource_group" is not among them.
  * GET /v1/models only lists the models the token may use.

Review changes relative to the submitted patch:
  * New has_access() returns a bool; check_access() is a thin wrapper
    around it, and list_models filters with has_access() instead of
    raising and swallowing one HTTPException per model.
  * has_access() falls back to empty settings when the config lookup
    yields nothing, instead of calling .get on that result.
  * Hunk headers of routers/openai.py were recomputed for the new line
    counts; the blob ids on the index lines are those of the submitted
    patch.

NOTE(review): this text was rebuilt from a copy whose line breaks and
indentation had been collapsed. Blank context lines and the indentation
of context lines were inferred from the hunk line counts and the
surrounding formatting. Confirm with "git apply --check" before use.

diff --git a/src/skvaider/__init__.py b/src/skvaider/__init__.py
index 18d3fca..e7a416f 100644
--- a/src/skvaider/__init__.py
+++ b/src/skvaider/__init__.py
@@ -67,7 +67,7 @@ async def lifespan(app: FastAPI, registry: svcs.Registry):
 
     model_config = skvaider.routers.openai.ModelConfig(config.openai.models)
 
-    pool = skvaider.routers.openai.Pool()
+    pool = skvaider.routers.openai.Pool(model_config)
     for backend_config in config.backend:
         if backend_config.type != "openai":
             continue
diff --git a/src/skvaider/auth.py b/src/skvaider/auth.py
index fa727a1..cfa19f3 100644
--- a/src/skvaider/auth.py
+++ b/src/skvaider/auth.py
@@ -43,3 +43,5 @@ async def verify_token(
     # We could specify explicit exceptions here but go the safe route and just catch all in case the lib addes one
     except Exception:
         raise HTTPException(401, detail="Bad authentication")
+
+    return db_token
diff --git a/src/skvaider/routers/openai.py b/src/skvaider/routers/openai.py
index ee3579e..db1d86e 100644
--- a/src/skvaider/routers/openai.py
+++ b/src/skvaider/routers/openai.py
@@ -11,11 +11,12 @@
 import httpx
 import structlog
 import svcs
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, HTTPException, Request, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, ConfigDict, Field
 
 from skvaider import utils
+from skvaider.auth import verify_token
 
 router = APIRouter()
 
@@ -24,6 +25,31 @@
 log = structlog.stdlib.get_logger()
 
 
+def has_access(token: dict, model_id: str, model_config) -> bool:
+    """Return whether `token` may use the model `model_id`.
+
+    A model without a non-empty "resource_groups" setting is open to every
+    token. Otherwise the token's "resource_group" must be one of the
+    listed groups; a token without a resource group is refused.
+    """
+    # ModelConfig.get() is defined outside this patch; fall back to empty
+    # settings in case it yields None for a model that is not configured.
+    model_settings = model_config.get(model_id) or {}
+    allowed_resource_groups = model_settings.get("resource_groups")
+    if not allowed_resource_groups:
+        return True
+    return token.get("resource_group") in allowed_resource_groups
+
+
+def check_access(token: dict, model_id: str, model_config) -> None:
+    """Raise a 403 HTTPException unless `token` may use `model_id`."""
+    if not has_access(token, model_id, model_config):
+        raise HTTPException(
+            status_code=403,
+            detail=f"Access to model '{model_id}' is restricted.",
+        )
+
+
 class AIModel(BaseModel):
     """Model object per backend."""
 
@@ -232,13 +258,17 @@ class Pool:
     backends: list["Backend"]
     health_check_tasks: list[asyncio.Task]
     queues: dict[str, asyncio.Queue]  # one queue per model
+    model_config: ModelConfig
 
-    def __init__(self):
+    def __init__(self, model_config: ModelConfig = None):
+        if model_config is None:
+            model_config = ModelConfig({})
         self.backends = []
         self.health_check_tasks = []
         self.queues = {}
         self.models = {}
         self.queue_tasks = {}
+        self.model_config = model_config
 
     def add_backend(self, backend):
         self.backends.append(backend)
@@ -435,10 +465,13 @@ def __init__(self, services: svcs.fastapi.DepContainer):
         self.services = services
         self.pool = self.services.get(Pool)
 
-    async def proxy(self, request, endpoint, allow_stream=True):
+    async def proxy(self, request, endpoint, token: dict, allow_stream=True):
         request_data = await request.json()
         request_data["store"] = False
         request.state.model = request_data["model"]
+
+        check_access(token, request.state.model, self.pool.model_config)
+
         request.state.stream = allow_stream and request_data.get(
             "stream", False
         )
@@ -484,34 +517,55 @@ class ListResponse(BaseModel, Generic[T]):
 @router.get("/v1/models")
 async def list_models(
     services: svcs.fastapi.DepContainer,
+    token: dict = Security(verify_token),
 ) -> ListResponse[AIModel]:
     pool = services.get(Pool)
-    return ListResponse[AIModel](data=pool.models.values())
+    models = [
+        m
+        for m in pool.models.values()
+        if has_access(token, m.id, pool.model_config)
+    ]
+    return ListResponse[AIModel](data=models)
 
 
 @router.get("/v1/models/{model_id}")
 async def get_model(
-    model_id: str, services: svcs.fastapi.DepContainer
+    model_id: str,
+    services: svcs.fastapi.DepContainer,
+    token: dict = Security(verify_token),
 ) -> AIModel:
     pool = services.get(Pool)
+    check_access(token, model_id, pool.model_config)
     return pool.models[model_id]
 
 
 @router.post("/v1/chat/completions")
 async def chat_completions(
-    r: Request, services: svcs.fastapi.DepContainer
+    r: Request,
+    services: svcs.fastapi.DepContainer,
+    token: dict = Security(verify_token),
 ) -> Any:
     proxy = OpenAIProxy(services)
-    return await proxy.proxy(r, "/v1/chat/completions")
+    return await proxy.proxy(r, "/v1/chat/completions", token=token)
 
 
 @router.post("/v1/completions")
-async def completions(r: Request, services: svcs.fastapi.DepContainer) -> Any:
+async def completions(
+    r: Request,
+    services: svcs.fastapi.DepContainer,
+    token: dict = Security(verify_token),
+) -> Any:
     proxy = OpenAIProxy(services)
-    return await proxy.proxy(r, "/v1/completions")
+    return await proxy.proxy(r, "/v1/completions", token=token)
 
 
 @router.post("/v1/embeddings")
-async def embeddings(r: Request, services: svcs.fastapi.DepContainer) -> Any:
+async def embeddings(
+    r: Request,
+    services: svcs.fastapi.DepContainer,
+    token: dict = Security(verify_token),
+) -> Any:
     proxy = OpenAIProxy(services)
-    return await proxy.proxy(r, "/v1/embeddings", allow_stream=False)
+    return await proxy.proxy(
+        r, "/v1/embeddings", token=token, allow_stream=False
+    )
diff --git a/src/skvaider/tests/test_access_restrictions.py b/src/skvaider/tests/test_access_restrictions.py
new file mode 100644
index 0000000..59f5c1a
--- /dev/null
+++ b/src/skvaider/tests/test_access_restrictions.py
@@ -0,0 +1,125 @@
+import asyncio
+import base64
+import json
+
+import pytest
+import svcs
+from fastapi.testclient import TestClient
+
+from skvaider import app_factory
+from skvaider.auth import AuthTokens
+from skvaider.conftest import DUMMY_TOKENS, hasher
+from skvaider.routers.openai import AIModel, Backend, ModelConfig, Pool
+
+
+class MockBackend(Backend):
+    async def monitor_health_and_update_models(self, pool):
+        pool.update_model_maps()
+        while True:
+            await asyncio.sleep(1)
+
+
+@pytest.fixture
+def restricted_app_client(token_db):
+
+    config_data = {
+        "restricted-model": {"resource_groups": ["research"]},
+        "public-model": {},
+    }
+    model_config = ModelConfig(config_data)
+
+    pool = Pool(model_config)
+    backend = MockBackend("http://mock", model_config)
+
+    # Pre-populate backend models so they are ready when map updates
+    backend.models = {
+        "restricted-model": AIModel(
+            id="restricted-model", owned_by="me", backend=backend
+        ),
+        "public-model": AIModel(
+            id="public-model", owned_by="me", backend=backend
+        ),
+    }
+
+    @svcs.fastapi.lifespan
+    async def lifespan(app, registry):
+        # Add backend here where we have a loop
+        pool.add_backend(backend)
+        # Give it a tiny bit of time to update the pool maps (the task runs properly now)
+        await asyncio.sleep(0.01)
+
+        registry.register_value(Pool, pool)
+        registry.register_value(AuthTokens, DUMMY_TOKENS)
+        yield
+        pool.close()
+
+    app = app_factory(lifespan=lifespan)
+    with TestClient(app) as client:
+        yield client
+
+
+def create_auth_header(resource_group=None):
+    secret = "secret"
+    token_data = {"id": "user", "secret": secret}
+
+    # Update the DB record which is what check_access reads
+    DUMMY_TOKENS.data["user"] = {
+        "secret_hash": hasher.hash(secret),
+        "resource_group": resource_group,
+    }
+
+    auth_token = base64.b64encode(
+        json.dumps(token_data).encode("utf-8")
+    ).decode("ascii")
+    return {"Authorization": f"Bearer {auth_token}"}
+
+
+def test_access_restricted_model_allowed(restricted_app_client):
+    headers = create_auth_header(resource_group="research")
+
+    resp = restricted_app_client.get("/openai/v1/models", headers=headers)
+    assert resp.status_code == 200
+    ids = [m["id"] for m in resp.json()["data"]]
+    assert "restricted-model" in ids
+
+    resp = restricted_app_client.get(
+        "/openai/v1/models/restricted-model", headers=headers
+    )
+    assert resp.status_code == 200
+
+
+def test_access_restricted_model_denied(restricted_app_client):
+    headers = create_auth_header(resource_group="marketing")
+
+    resp = restricted_app_client.get("/openai/v1/models", headers=headers)
+    assert resp.status_code == 200
+    ids = [m["id"] for m in resp.json()["data"]]
+    assert "restricted-model" not in ids
+    assert "public-model" in ids
+
+    resp = restricted_app_client.get(
+        "/openai/v1/models/restricted-model", headers=headers
+    )
+    assert resp.status_code == 403
+
+
+def test_access_public_model(restricted_app_client):
+    headers = create_auth_header(resource_group="marketing")
+    resp = restricted_app_client.get(
+        "/openai/v1/models/public-model", headers=headers
+    )
+    assert resp.status_code == 200
+
+
+def test_access_restricted_model_no_resource_group(restricted_app_client):
+    headers = create_auth_header(resource_group=None)
+
+    resp = restricted_app_client.get("/openai/v1/models", headers=headers)
+    assert resp.status_code == 200
+    ids = [m["id"] for m in resp.json()["data"]]
+    assert "restricted-model" not in ids
+
+    resp = restricted_app_client.get(
+        "/openai/v1/models/restricted-model", headers=headers
+    )
+    assert resp.status_code == 403