Skip to content

Commit c6d7d2e

Browse files
committed
wip
1 parent eaef29a commit c6d7d2e

File tree

12 files changed

+79
-26
lines changed

12 files changed

+79
-26
lines changed

agents/chat/src/chat/agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AgentSkill,
1010
Message,
1111
)
12+
from agentstack_sdk.server.middleware.platform_auth_backend import PlatformAuthBackend
1213
from beeai_framework.agents.requirement.utils._tool import FinalAnswerTool
1314
from beeai_framework.errors import FrameworkError
1415
from pydantic import BaseModel
@@ -300,6 +301,7 @@ def serve():
300301
port=int(os.getenv("PORT", 8000)),
301302
configure_telemetry=True,
302303
context_store=PlatformContextStore(),
304+
auth_backend=PlatformAuthBackend(skip_audience_validation=True),
303305
)
304306
except KeyboardInterrupt:
305307
pass

apps/agentstack-cli/src/agentstack_cli/api.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import openai
1515
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
1616
from a2a.types import AgentCard
17+
from agentstack_sdk.platform.context import ContextToken
1718
from httpx import HTTPStatusError
1819
from httpx._types import RequestFiles
1920

@@ -103,14 +104,10 @@ async def api_stream(
103104

104105

105106
@asynccontextmanager
106-
async def a2a_client(agent_card: AgentCard, use_auth: bool = True) -> AsyncIterator[Client]:
107+
async def a2a_client(agent_card: AgentCard, context_token: ContextToken) -> AsyncIterator[Client]:
107108
try:
108109
async with httpx.AsyncClient(
109-
headers=(
110-
{"Authorization": f"Bearer {token}"}
111-
if use_auth and (token := await config.auth_manager.load_auth_token())
112-
else {}
113-
),
110+
headers={"Authorization": f"Bearer {context_token.token.get_secret_value()}"},
114111
follow_redirects=True,
115112
timeout=timedelta(hours=1).total_seconds(),
116113
) as httpx_client:

apps/agentstack-cli/src/agentstack_cli/commands/agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ async def _run_agent(
556556
console.print() # Add newline after completion
557557
return
558558
case Task(id=task_id), TaskStatusUpdateEvent(
559-
status=TaskStatus(state=TaskState.working, message=message)
559+
status=TaskStatus(state=TaskState.working | TaskState.submitted, message=message)
560560
):
561561
# Handle streaming content during working state
562562
if message:
@@ -944,7 +944,7 @@ async def run_agent(
944944
if interaction_mode == InteractionMode.MULTI_TURN:
945945
console.print(f"{user_greeting}\n")
946946
turn_input = await _ask_form_questions(initial_form_render) if initial_form_render else handle_input()
947-
async with a2a_client(provider.agent_card) as client:
947+
async with a2a_client(provider.agent_card, context_token=context_token) as client:
948948
while True:
949949
console.print()
950950
await _run_agent(
@@ -961,7 +961,7 @@ async def run_agent(
961961
user_greeting = ui_annotations.get("user_greeting", None) or "Enter your instructions."
962962
console.print(f"{user_greeting}\n")
963963
console.print()
964-
async with a2a_client(provider.agent_card) as client:
964+
async with a2a_client(provider.agent_card, context_token=context_token) as client:
965965
await _run_agent(
966966
client,
967967
input=await _ask_form_questions(initial_form_render) if initial_form_render else handle_input(),
@@ -972,7 +972,7 @@ async def run_agent(
972972
)
973973

974974
else:
975-
async with a2a_client(provider.agent_card) as client:
975+
async with a2a_client(provider.agent_card, context_token=context_token) as client:
976976
await _run_agent(
977977
client,
978978
input,

apps/agentstack-cli/uv.lock

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/services/platform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _get_header_token(self, request_context: RequestContext) -> pydantic.Secret[
7373
assert call_context
7474
if isinstance(call_context.user, PlatformAuthenticatedUser):
7575
header_token = call_context.user.auth_token.get_secret_value()
76-
elif (headers := call_context.state.get("headers")) and (auth_header := headers.get("authorization")):
76+
elif auth_header := call_context.state.get("headers", {}).get("authorization", None):
7777
_scheme, header_token = get_authorization_scheme_param(auth_header)
7878
return pydantic.Secret(header_token) if header_token else None
7979

apps/agentstack-sdk-py/src/agentstack_sdk/server/middleware/platform_auth_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, public_url: str | None = None, skip_audience_validation: bool
7676
if skip_audience_validation is not None
7777
else os.getenv("PLATFORM_AUTH__SKIP_AUDIENCE_VALIDATION", "false").lower() in ("true", "1")
7878
)
79-
_audience = public_url or os.getenv("PLATFORM_AUTH__PUBLIC_URL")
79+
_audience = public_url or os.getenv("PLATFORM_AUTH__PUBLIC_URL", "http://host.docker.internal:8333")
8080
if not self.skip_audience_validation and not _audience:
8181
raise ValueError(
8282
"Public URL must be provided if audience validation is enabled (hint: set PLATFORM_AUTH__PUBLIC_URL env variable)"
@@ -89,7 +89,7 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas
8989
# We construct a Request object from the scope for compatibility with HTTPBearer and logging
9090
request = Request(scope=conn.scope)
9191

92-
if request.url.path in ["/", "/healthcheck", "/.well-known/agent-card.json"]:
92+
if request.url.path in ["/healthcheck", "/.well-known/agent-card.json"]:
9393
return None
9494

9595
if not (auth := await self.security(request)):

apps/agentstack-sdk-py/src/agentstack_sdk/server/server.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
from a2a.types import AgentExtension
2222
from fastapi import FastAPI
2323
from fastapi.applications import AppType
24+
from fastapi.responses import PlainTextResponse
2425
from httpx import HTTPError, HTTPStatusError
2526
from pydantic import AnyUrl
27+
from starlette.authentication import AuthenticationBackend, AuthenticationError
28+
from starlette.middleware.authentication import AuthenticationMiddleware
29+
from starlette.requests import HTTPConnection
2630
from starlette.types import Lifespan
2731
from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential
2832

@@ -127,7 +131,7 @@ async def serve(
127131
factory: bool = False,
128132
h11_max_incomplete_event_size: int | None = None,
129133
self_registration_client_factory: Callable[[], PlatformClient] | None = None,
130-
auth_middleware: Any | None = None,
134+
auth_backend: AuthenticationBackend | None = None,
131135
) -> None:
132136
if self.server:
133137
raise RuntimeError("The server is already running")
@@ -199,6 +203,13 @@ async def _lifespan_fn(app: FastAPI) -> AsyncGenerator[None, None]:
199203
request_context_builder=request_context_builder,
200204
)
201205

206+
if auth_backend:
207+
208+
def on_error(connection: HTTPConnection, error: AuthenticationError) -> PlainTextResponse:
209+
return PlainTextResponse("Unauthorized", status_code=401)
210+
211+
app.add_middleware(AuthenticationMiddleware, backend=auth_backend, on_error=on_error)
212+
202213
if configure_logger:
203214
configure_logger_func(log_level)
204215

apps/agentstack-server/src/agentstack_server/api/auth/auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from authlib.oidc.discovery import OpenIDProviderMetadata
1818
from authlib.oidc.discovery import get_well_known_url as oidc_get_well_known_url
1919
from fastapi.security import HTTPAuthorizationCredentials
20+
from kink import inject
2021
from pydantic import AwareDatetime, BaseModel
2122

2223
from agentstack_server.api.auth.errors import (

apps/agentstack-server/src/agentstack_server/api/routes/auth.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,5 @@ def protected_resource_metadata(
2626
@well_known_router.get("/jwks")
2727
def jwks():
2828
config = get_configuration()
29-
if not config.auth.jwt_public_key or config.auth.jwt_public_key.get_secret_value() == "dummy":
30-
return {"keys": []}
31-
3229
key = JsonWebKey.import_key(config.auth.jwt_public_key.get_secret_value(), {"use": "sig", "alg": "RS256"})
3330
return {"keys": [key.as_dict()]}

apps/agentstack-server/src/agentstack_server/configuration.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,17 @@ def validate_auth(self):
134134
raise ValueError("JWT private and public keys must be provided if authentication is enabled")
135135
return self
136136

137+
@model_validator(mode="after")
138+
def set_default_jwt_keys(self):
139+
if self.jwt_private_key.get_secret_value() == "dummy" or self.jwt_public_key.get_secret_value() == "dummy":
140+
logger.warning("JWT private and public keys are not set. Generating default keys.")
141+
from authlib.jose import JsonWebKey
142+
143+
key = JsonWebKey.generate_key("RSA", 1024, is_private=True)
144+
self.jwt_private_key = Secret(key.as_pem(is_private=True).decode("utf-8"))
145+
self.jwt_public_key = Secret(key.as_pem(is_private=False).decode("utf-8"))
146+
return self
147+
137148

138149
class McpConfiguration(BaseModel):
139150
gateway_endpoint_url: AnyUrl = AnyUrl("http://forge-svc:4444")

0 commit comments

Comments
 (0)