Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/7705.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Trim unnecessary information from model service JWTs to make the tokens smaller
11 changes: 3 additions & 8 deletions src/ai/backend/appproxy/coordinator/api/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from uuid import UUID

import aiohttp_cors
import jwt
import sqlalchemy as sa
from aiohttp import web
from pydantic import AnyUrl, BaseModel, Field
Expand Down Expand Up @@ -297,13 +296,9 @@ async def generate_endpoint_api_token(
circuit: Circuit = await Circuit.find_by_endpoint(
sess, UUID(request.match_info["endpoint_id"]), load_worker=False, load_endpoint=False
)
payload = dict(circuit.dump_model())
payload["config"] = {}
payload["app_url"] = str(await circuit.get_endpoint_url(session=sess))

payload["user"] = str(params.user_uuid)
payload["exp"] = params.exp
encoded_jwt = jwt.encode(payload, root_ctx.local_config.secrets.jwt_secret, algorithm="HS256")
encoded_jwt = await circuit.generate_jwt(
sess, root_ctx.local_config.secrets.jwt_secret, params.user_uuid, params.exp
)
return PydanticResponse(EndpointAPITokenResponseModel(token=encoded_jwt))


Expand Down
16 changes: 16 additions & 0 deletions src/ai/backend/appproxy/coordinator/models/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Optional
from uuid import UUID

import jwt
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pgsql
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -414,6 +415,21 @@ def update_route_health_status(
return did_update_status
return False

async def generate_jwt(
    self, db_sess: AsyncSession, jwt_secret: str, created_user: UUID, exp: datetime
) -> str:
    """Encode this circuit's dumped model as an HS256-signed JWT.

    The claims start from ``self.dump_model()``. ``app_url``, ``user`` and
    ``exp`` are added on top, and the ``config`` and ``route_info`` entries
    are removed before signing.

    Args:
        db_sess: Database session forwarded to ``get_endpoint_url``.
        jwt_secret: Secret key passed to ``jwt.encode``.
        created_user: UUID stored (as a string) under the ``user`` claim.
        exp: Value stored under the ``exp`` claim.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    claims = dict(self.dump_model())

    # Extra claims that are not part of the dumped model.
    endpoint_url = await self.get_endpoint_url(session=db_sess)
    claims["app_url"] = str(endpoint_url)
    claims["user"] = str(created_user)
    claims["exp"] = exp

    # Mask unrelated & sensitive entries; `del` keeps the original behavior
    # of raising KeyError if the dumped model lacks one of them.
    for masked_key in ("config", "route_info"):
        del claims[masked_key]

    return jwt.encode(claims, jwt_secret, algorithm="HS256")

@property
def traefik_services(self) -> dict[str, Any]:
# Use health-aware route filtering
Expand Down
Loading