diff --git a/changes/7705.feature.md b/changes/7705.feature.md new file mode 100644 index 00000000000..642d4eb0375 --- /dev/null +++ b/changes/7705.feature.md @@ -0,0 +1 @@ +Trim down unnecessary informations from model service JWTs to make model service token smaller diff --git a/src/ai/backend/appproxy/coordinator/api/endpoint.py b/src/ai/backend/appproxy/coordinator/api/endpoint.py index 506be3e1bb1..98c55a0d85b 100644 --- a/src/ai/backend/appproxy/coordinator/api/endpoint.py +++ b/src/ai/backend/appproxy/coordinator/api/endpoint.py @@ -7,7 +7,6 @@ from uuid import UUID import aiohttp_cors -import jwt import sqlalchemy as sa from aiohttp import web from pydantic import AnyUrl, BaseModel, Field @@ -297,13 +296,9 @@ async def generate_endpoint_api_token( circuit: Circuit = await Circuit.find_by_endpoint( sess, UUID(request.match_info["endpoint_id"]), load_worker=False, load_endpoint=False ) - payload = dict(circuit.dump_model()) - payload["config"] = {} - payload["app_url"] = str(await circuit.get_endpoint_url(session=sess)) - - payload["user"] = str(params.user_uuid) - payload["exp"] = params.exp - encoded_jwt = jwt.encode(payload, root_ctx.local_config.secrets.jwt_secret, algorithm="HS256") + encoded_jwt = await circuit.generate_jwt( + sess, root_ctx.local_config.secrets.jwt_secret, params.user_uuid, params.exp + ) return PydanticResponse(EndpointAPITokenResponseModel(token=encoded_jwt)) diff --git a/src/ai/backend/appproxy/coordinator/models/circuit.py b/src/ai/backend/appproxy/coordinator/models/circuit.py index fd0f5676cbf..64cf58c2b2b 100644 --- a/src/ai/backend/appproxy/coordinator/models/circuit.py +++ b/src/ai/backend/appproxy/coordinator/models/circuit.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Optional from uuid import UUID +import jwt import sqlalchemy as sa from sqlalchemy.dialects import postgresql as pgsql from sqlalchemy.ext.asyncio import AsyncSession @@ -414,6 +415,21 @@ def update_route_health_status( return did_update_status return False + async def generate_jwt( + self, db_sess: AsyncSession, jwt_secret: str, created_user: UUID, exp: datetime + ) -> str: + payload = dict(self.dump_model()) + + # inject extra information + payload["app_url"] = str(await self.get_endpoint_url(session=db_sess)) + payload["user"] = str(created_user) + payload["exp"] = exp + # mask unrelated & sensitive information + del payload["config"] + del payload["route_info"] + + return jwt.encode(payload, jwt_secret, algorithm="HS256") + @property def traefik_services(self) -> dict[str, Any]: # Use health-aware route filtering