Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/skvaider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def lifespan(app: FastAPI, registry: svcs.Registry):

model_config = skvaider.routers.openai.ModelConfig(config.openai.models)

pool = skvaider.routers.openai.Pool()
pool = skvaider.routers.openai.Pool(model_config)
for backend_config in config.backend:
if backend_config.type != "openai":
continue
Expand Down
2 changes: 2 additions & 0 deletions src/skvaider/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ async def verify_token(
# We could list explicit exceptions here, but we take the safe route and catch everything in case the library adds one
except Exception:
raise HTTPException(401, detail="Bad authentication")

return db_token
65 changes: 54 additions & 11 deletions src/skvaider/routers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import httpx
import structlog
import svcs
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, HTTPException, Request, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field

from skvaider import utils
from skvaider.auth import verify_token

router = APIRouter()

Expand All @@ -24,6 +25,18 @@
log = structlog.stdlib.get_logger()


def check_access(token: dict, model_id: str, model_config):
    """Raise a 403 when the token's resource group may not use the model.

    Args:
        token: Token record; only its ``"resource_group"`` entry is read.
        model_id: Identifier of the model being requested.
        model_config: Object whose ``get(model_id)`` returns the settings
            mapping for that model; only ``"resource_groups"`` is read.

    Raises:
        HTTPException: With status 403 when the model lists resource groups
            and the token's group is not among them.
    """
    permitted_groups = model_config.get(model_id).get("resource_groups")

    # A missing or empty list means the model is not restricted.
    if not permitted_groups:
        return

    if token.get("resource_group") in permitted_groups:
        return

    raise HTTPException(
        status_code=403,
        detail=f"Access to model '{model_id}' is restricted.",
    )


class AIModel(BaseModel):
"""Model object per backend."""

Expand Down Expand Up @@ -232,13 +245,17 @@ class Pool:
backends: list["Backend"]
health_check_tasks: list[asyncio.Task]
queues: dict[str, asyncio.Queue] # one queue per model
model_config: ModelConfig

def __init__(self):
def __init__(self, model_config: ModelConfig = None):
if model_config is None:
model_config = ModelConfig({})
self.backends = []
self.health_check_tasks = []
self.queues = {}
self.models = {}
self.queue_tasks = {}
self.model_config = model_config

def add_backend(self, backend):
self.backends.append(backend)
Expand Down Expand Up @@ -435,10 +452,13 @@ def __init__(self, services: svcs.fastapi.DepContainer):
self.services = services
self.pool = self.services.get(Pool)

async def proxy(self, request, endpoint, allow_stream=True):
async def proxy(self, request, endpoint, token: dict, allow_stream=True):
request_data = await request.json()
request_data["store"] = False
request.state.model = request_data["model"]

check_access(token, request.state.model, self.pool.model_config)

request.state.stream = allow_stream and request_data.get(
"stream", False
)
Expand Down Expand Up @@ -484,34 +504,57 @@ class ListResponse(BaseModel, Generic[T]):
@router.get("/v1/models")
async def list_models(
services: svcs.fastapi.DepContainer,
token: dict = Security(verify_token),
) -> ListResponse[AIModel]:
pool = services.get(Pool)
return ListResponse[AIModel](data=pool.models.values())
models = []
for m in pool.models.values():
try:
check_access(token, m.id, pool.model_config)
models.append(m)
except HTTPException:
pass
return ListResponse[AIModel](data=models)


@router.get("/v1/models/{model_id}")
async def get_model(
model_id: str, services: svcs.fastapi.DepContainer
model_id: str,
services: svcs.fastapi.DepContainer,
token: dict = Security(verify_token),
) -> AIModel:
pool = services.get(Pool)
check_access(token, model_id, pool.model_config)
return pool.models[model_id]


@router.post("/v1/chat/completions")
async def chat_completions(
r: Request, services: svcs.fastapi.DepContainer
r: Request,
services: svcs.fastapi.DepContainer,
token: dict = Security(verify_token),
) -> Any:
proxy = OpenAIProxy(services)
return await proxy.proxy(r, "/v1/chat/completions")
return await proxy.proxy(r, "/v1/chat/completions", token=token)


@router.post("/v1/completions")
async def completions(r: Request, services: svcs.fastapi.DepContainer) -> Any:
async def completions(
r: Request,
services: svcs.fastapi.DepContainer,
token: dict = Security(verify_token),
) -> Any:
proxy = OpenAIProxy(services)
return await proxy.proxy(r, "/v1/completions")
return await proxy.proxy(r, "/v1/completions", token=token)


@router.post("/v1/embeddings")
async def embeddings(r: Request, services: svcs.fastapi.DepContainer) -> Any:
async def embeddings(
r: Request,
services: svcs.fastapi.DepContainer,
token: dict = Security(verify_token),
) -> Any:
proxy = OpenAIProxy(services)
return await proxy.proxy(r, "/v1/embeddings", allow_stream=False)
return await proxy.proxy(
r, "/v1/embeddings", token=token, allow_stream=False
)
125 changes: 125 additions & 0 deletions src/skvaider/tests/test_access_restrictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import asyncio
import base64
import json

import pytest
import svcs
from fastapi.testclient import TestClient

from skvaider import app_factory
from skvaider.auth import AuthTokens
from skvaider.conftest import DUMMY_TOKENS, hasher
from skvaider.routers.openai import AIModel, Backend, ModelConfig, Pool


class MockBackend(Backend):
    """Backend stand-in whose monitor refreshes the pool's model maps once and then idles."""

    async def monitor_health_and_update_models(self, pool):
        """Ask ``pool`` to rebuild its model maps, then sleep in a loop forever."""
        # One-off refresh so the pool picks up whatever models were assigned
        # to this backend before the monitor started.
        pool.update_model_maps()
        # Keep the coroutine alive; it never returns on its own.
        while True:
            await asyncio.sleep(1)


@pytest.fixture
def restricted_app_client(token_db):
    """Yield a TestClient for an app serving one restricted and one public model.

    ``restricted-model`` is limited to the ``research`` resource group, while
    ``public-model`` has no restriction configured. ``token_db`` is not used
    in the body; presumably it is requested for its setup side effects only.
    """
    model_config = ModelConfig(
        {
            "restricted-model": {"resource_groups": ["research"]},
            "public-model": {},
        }
    )

    pool = Pool(model_config)
    backend = MockBackend("http://mock", model_config)

    # Seed the backend's models up front so they are available as soon as
    # the pool's model maps are refreshed.
    backend.models = {
        name: AIModel(id=name, owned_by="me", backend=backend)
        for name in ("restricted-model", "public-model")
    }

    @svcs.fastapi.lifespan
    async def lifespan(app, registry):
        # The backend is attached here because an event loop is running
        # inside the lifespan.
        pool.add_backend(backend)
        # Brief pause to let the pool maps get updated.
        await asyncio.sleep(0.01)

        registry.register_value(Pool, pool)
        registry.register_value(AuthTokens, DUMMY_TOKENS)
        yield
        pool.close()

    application = app_factory(lifespan=lifespan)
    with TestClient(application) as test_client:
        yield test_client


def create_auth_header(resource_group=None):
    """Register the test user's token record and return a matching Bearer header.

    Args:
        resource_group: Value stored as the token's ``"resource_group"``;
            ``None`` stores the record without a group.

    Returns:
        A headers dict whose ``Authorization`` value is ``Bearer`` followed by
        the base64-encoded JSON credentials for user ``"user"``.
    """
    secret = "secret"

    # Store the record (hashed secret plus resource group) that the access
    # check is evaluated against.
    DUMMY_TOKENS.data["user"] = {
        "secret_hash": hasher.hash(secret),
        "resource_group": resource_group,
    }

    payload = json.dumps({"id": "user", "secret": secret}).encode("utf-8")
    encoded = base64.b64encode(payload).decode("ascii")
    return {"Authorization": f"Bearer {encoded}"}


def test_access_restricted_model_allowed(restricted_app_client):
    """A token in the ``research`` group can list and fetch the restricted model."""
    auth = create_auth_header(resource_group="research")

    listing = restricted_app_client.get("/openai/v1/models", headers=auth)
    assert listing.status_code == 200
    listed_ids = [entry["id"] for entry in listing.json()["data"]]
    assert "restricted-model" in listed_ids

    detail = restricted_app_client.get(
        "/openai/v1/models/restricted-model", headers=auth
    )
    assert detail.status_code == 200


def test_access_restricted_model_denied(restricted_app_client):
    """A token in another group sees only the public model and gets 403 on the restricted one."""
    auth = create_auth_header(resource_group="marketing")

    listing = restricted_app_client.get("/openai/v1/models", headers=auth)
    assert listing.status_code == 200
    listed_ids = [entry["id"] for entry in listing.json()["data"]]
    assert "restricted-model" not in listed_ids
    assert "public-model" in listed_ids

    detail = restricted_app_client.get(
        "/openai/v1/models/restricted-model", headers=auth
    )
    assert detail.status_code == 403


def test_access_public_model(restricted_app_client):
    """A model without configured resource groups is reachable from any group."""
    auth = create_auth_header(resource_group="marketing")
    detail = restricted_app_client.get(
        "/openai/v1/models/public-model", headers=auth
    )
    assert detail.status_code == 200


def test_access_restricted_model_no_resource_group(restricted_app_client):
    """A token without a resource group cannot list or fetch the restricted model."""
    auth = create_auth_header(resource_group=None)

    listing = restricted_app_client.get("/openai/v1/models", headers=auth)
    assert listing.status_code == 200
    listed_ids = [entry["id"] for entry in listing.json()["data"]]
    assert "restricted-model" not in listed_ids

    detail = restricted_app_client.get(
        "/openai/v1/models/restricted-model", headers=auth
    )
    assert detail.status_code == 403
Loading