Skip to content

Commit 2cba3e2

Browse files
authored
add mcp-headers support (#1716)
Add mcp headers support to allow authentication to mcp servers and their target AAP services issue: https://issues.redhat.com/browse/AAP-49395 Signed-off-by: Djebran Lezzoum <[email protected]>
1 parent bd6adf8 commit 2cba3e2

File tree

8 files changed

+111
-6
lines changed

8 files changed

+111
-6
lines changed

ansible_ai_connect/ai/api/model_pipelines/http/configuration.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
# ENABLE_HEALTHCHECK_XXX
3232

3333

34+
class MCPServersSerializer(serializers.Serializer):
35+
name = serializers.CharField(required=True)
36+
type = serializers.CharField(required=True)
37+
38+
3439
@dataclass
3540
class HttpConfiguration(BaseConfig):
3641

@@ -42,13 +47,16 @@ def __init__(
4247
enable_health_check: Optional[bool],
4348
verify_ssl: bool,
4449
stream: bool = False,
50+
mcp_servers: Optional[list[dict[str, str]]] = None,
4551
):
4652
super().__init__(inference_url, model_id, timeout, enable_health_check)
4753
self.verify_ssl = verify_ssl
4854
self.stream = stream
55+
self.mcp_servers = mcp_servers or []
4956

5057
verify_ssl: bool
5158
stream: bool
59+
mcp_servers: Optional[list[dict[str, str]]] = None
5260

5361

5462
@Register(api_type="http")
@@ -64,6 +72,7 @@ def __init__(self, **kwargs):
6472
enable_health_check=kwargs["enable_health_check"],
6573
verify_ssl=kwargs["verify_ssl"],
6674
stream=kwargs["stream"],
75+
mcp_servers=kwargs["mcp_servers"],
6776
),
6877
)
6978

@@ -72,3 +81,6 @@ def __init__(self, **kwargs):
7281
class HttpConfigurationSerializer(BaseConfigSerializer):
7382
verify_ssl = serializers.BooleanField(required=False, default=True)
7483
stream = serializers.BooleanField(required=False, default=False)
84+
mcp_servers = serializers.ListSerializer(
85+
child=MCPServersSerializer(), allow_empty=True, required=False, default=None
86+
)

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import aiohttp
2121
import requests
22+
from django.conf import settings
2223
from django.http import StreamingHttpResponse
2324
from health_check.exceptions import ServiceUnavailable
2425

@@ -176,7 +177,7 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
176177
conversation_id = params.conversation_id
177178
provider = params.provider
178179
model_id = params.model_id
179-
system_prompt = params.system_prompt
180+
system_prompt = params.system_prompt or settings.CHATBOT_DEFAULT_SYSTEM_PROMPT
180181

181182
data = {
182183
"query": query,
@@ -188,6 +189,10 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
188189
if system_prompt:
189190
data["system_prompt"] = str(system_prompt)
190191

192+
headers = self.headers or {}
193+
if params.mcp_headers:
194+
headers["MCP-HEADERS"] = json.dumps(params.mcp_headers)
195+
191196
response = requests.post(
192197
self.config.inference_url + "/v1/query",
193198
headers=self.headers,
@@ -262,11 +267,14 @@ async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerat
262267
"Accept": "application/json,text/event-stream",
263268
}
264269

270+
if params.mcp_headers:
271+
headers["MCP-HEADERS"] = json.dumps(params.mcp_headers)
272+
265273
query = params.query
266274
conversation_id = params.conversation_id
267275
provider = params.provider
268276
model_id = params.model_id
269-
system_prompt = params.system_prompt
277+
system_prompt = params.system_prompt or settings.CHATBOT_DEFAULT_SYSTEM_PROMPT
270278
media_type = params.media_type
271279

272280
data = {

ansible_ai_connect/ai/api/model_pipelines/pipelines.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from abc import ABCMeta, abstractmethod
1717
from typing import Any, Dict, Generic, Optional
1818

19-
from attrs import define
19+
from attrs import define, field
2020
from django.conf import settings
2121
from django.http import StreamingHttpResponse
2222
from rest_framework import serializers
@@ -231,6 +231,7 @@ class ChatBotParameters:
231231
model_id: str
232232
conversation_id: Optional[str]
233233
system_prompt: str
234+
mcp_headers: Optional[dict[str, dict[str, str]]] = field(kw_only=True, default=None)
234235

235236
@classmethod
236237
def init(
@@ -240,13 +241,15 @@ def init(
240241
model_id: Optional[str] = None,
241242
conversation_id: Optional[str] = None,
242243
system_prompt: Optional[str] = None,
244+
mcp_headers: Optional[dict[str, dict[str, str]]] = None,
243245
):
244246
return cls(
245247
query=query,
246248
provider=provider,
247249
model_id=model_id,
248250
conversation_id=conversation_id,
249251
system_prompt=system_prompt,
252+
mcp_headers=mcp_headers,
250253
)
251254

252255

@@ -266,6 +269,7 @@ def init(
266269
conversation_id: Optional[str] = None,
267270
system_prompt: Optional[str] = None,
268271
media_type: Optional[str] = None,
272+
mcp_headers: Optional[dict[str, dict[str, str]]] = None,
269273
):
270274
return cls(
271275
query=query,
@@ -274,6 +278,7 @@ def init(
274278
conversation_id=conversation_id,
275279
system_prompt=system_prompt,
276280
media_type=media_type,
281+
mcp_headers=mcp_headers,
277282
)
278283

279284

ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def mock_pipeline_config(pipeline_provider: t_model_mesh_api_type, **kwargs):
8181
enable_health_check=extract("enable_health_check", False, **kwargs),
8282
verify_ssl=extract("verify_ssl", False, **kwargs),
8383
stream=extract("stream", False, **kwargs),
84+
mcp_servers=extract("mcp_servers", [], **kwargs),
8485
)
8586
case "llamacpp":
8687
return LlamaCppConfiguration(

ansible_ai_connect/ai/api/versions/v1/tests/test_jwt_authentication.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55
from datetime import datetime, timedelta
66
from http import HTTPStatus
77
from unittest import mock
8+
from unittest.mock import Mock, patch
89

910
import jwt
1011
import requests
1112
from ansible_base.resource_registry.models.service_identifier import service_id
1213
from cryptography.hazmat.backends import default_backend
1314
from cryptography.hazmat.primitives import serialization
1415
from cryptography.hazmat.primitives.asymmetric import rsa
16+
from django.apps import apps
1517
from django.test import override_settings
1618
from rest_framework.test import APITransactionTestCase
1719

20+
from ansible_ai_connect.ai.api.model_pipelines.http.pipelines import HttpChatBotPipeline
21+
from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config
1822
from ansible_ai_connect.ai.api.versions.v1.test_base import API_VERSION
1923
from ansible_ai_connect.test_utils import APIVersionTestCaseBase
2024
from ansible_ai_connect.users.models import User
@@ -40,8 +44,35 @@
4044
ISSUER = "ansible-issuer"
4145
AUDIENCE = "ansible-services"
4246

47+
ANSIBLE_AI_MODEL_MESH_CONFIG = {
48+
"ModelPipelineChatBot": {
49+
"provider": "http",
50+
"config": {
51+
"inference_url": "http://localhost:8080",
52+
"model_id": "granite-3.3-8b-instruct",
53+
"enable_health_check": True,
54+
"mcp_servers": [
55+
{"name": "mcp::aap-controller", "type": "controller"},
56+
{"name": "mcp::aap-gateway", "type": "gateway"},
57+
{"name": "mcp::aap-lightspeed", "type": "lightspeed"},
58+
],
59+
},
60+
},
61+
}
62+
4363

4464
@override_settings(ANSIBLE_BASE_JWT_KEY=test_encryption_public_key)
65+
@patch.object(
66+
apps.get_app_config("ai"),
67+
"get_model_pipeline",
68+
Mock(
69+
return_value=HttpChatBotPipeline(
70+
mock_pipeline_config(
71+
"http", **ANSIBLE_AI_MODEL_MESH_CONFIG["ModelPipelineChatBot"]["config"]
72+
),
73+
)
74+
),
75+
)
4576
class TestJWTAuthentication(APIVersionTestCaseBase, APITransactionTestCase):
4677
api_version = API_VERSION
4778

@@ -113,5 +144,23 @@ def json(self):
113144
response = self.jwt_client.post(
114145
self.api_version_reverse("chat"), {"query": "Hello"}, format="json"
115146
)
147+
148+
mock_requests_post.assert_called_once()
149+
_, kwargs = mock_requests_post.call_args
150+
151+
headers = kwargs.get("headers", None)
152+
self.assertIsNotNone(headers)
153+
154+
mcp_headers_string = headers.get("MCP-HEADERS", None)
155+
self.assertIsNotNone(headers)
156+
157+
expected_mcp_headers = {
158+
"mcp::aap-controller": {"X-DAB-JW-TOKEN": self.encrypted_token},
159+
"mcp::aap-lightspeed": {"X-DAB-JW-TOKEN": self.encrypted_token},
160+
}
161+
mcp_headers = json.loads(mcp_headers_string)
162+
163+
self.assertDictEqual(mcp_headers, expected_mcp_headers)
164+
116165
self.assertEqual(response.status_code, HTTPStatus.OK)
117166
self.assertDictEqual(response.data, expected_response)

ansible_ai_connect/ai/api/views.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from rest_framework import permissions, serializers
3030
from rest_framework import status as rest_framework_status
3131
from rest_framework.generics import GenericAPIView
32+
from rest_framework.request import Request
3233
from rest_framework.response import Response
3334
from rest_framework.views import APIView
3435

@@ -108,6 +109,7 @@
108109
from ..feature_flags import FeatureFlags
109110
from .data.data_model import ContentMatchPayloadData, ContentMatchResponseDto
110111
from .model_pipelines.exceptions import ModelTimeoutError
112+
from .model_pipelines.http.configuration import HttpConfiguration
111113
from .permissions import (
112114
BlockUserWithoutSeat,
113115
BlockUserWithoutSeatAndWCAReadyOrg,
@@ -265,6 +267,27 @@ def finalize_response(self, request, response, *args, **kwargs):
265267
send_schema1_event(self.event)
266268
return response
267269

270+
@staticmethod
271+
def get_mcp_headers(request: Request, config: HttpConfiguration) -> dict:
272+
mcp_headers = {}
273+
jwt_header_name = "X-DAB-JW-TOKEN"
274+
token = request.headers.get(jwt_header_name, None)
275+
user = request.user
276+
if token and user.is_authenticated and user.aap_user:
277+
for mcp_server in config.mcp_servers:
278+
if mcp_server["type"] in ["controller", "eda", "hub", "lightspeed"]:
279+
mcp_headers[mcp_server["name"]] = {jwt_header_name: token}
280+
# This functionality seems experimental for gateway and does not allow the user to
281+
# access wide range of api endpoints, we need to find a solution for gateway,
282+
# but for the moment comment this code.
283+
# elif mcp_server["type"] == "gateway":
284+
# from ansible_base.resource_registry.resource_server import get_service_token
285+
# user_ansible_id = str(user.ansible_id_for_filter)
286+
# token = get_service_token(user_id=user_ansible_id, expiration=3600)
287+
# mcp_headers[mcp_server["name"]] = {"X-ANSIBLE-SERVICE-AUTH": token}
288+
289+
return mcp_headers
290+
268291

269292
class Completions(APIView):
270293
"""
@@ -1106,13 +1129,16 @@ def post(self, request) -> Response:
11061129
self.event.conversation_id = conversation_id
11071130
self.event.modelName = self.req_model_id or self.llm.config.model_id
11081131

1132+
mcp_headers = self.get_mcp_headers(request, self.llm.config)
1133+
11091134
data = self.llm.invoke(
11101135
ChatBotParameters.init(
11111136
query=req_query,
11121137
system_prompt=req_system_prompt,
11131138
model_id=self.req_model_id or self.llm.config.model_id,
11141139
provider=req_provider,
11151140
conversation_id=conversation_id,
1141+
mcp_headers=mcp_headers,
11161142
)
11171143
)
11181144

@@ -1122,10 +1148,10 @@ def post(self, request) -> Response:
11221148
raise ChatbotInvalidResponseException()
11231149

11241150
# Finalise Segment Event with response details
1125-
self.event.chat_truncated = bool(data["truncated"])
1151+
self.event.chat_truncated = bool(data.get("truncated", False))
11261152
self.event.chat_referenced_documents = [
11271153
ChatBotResponseDocsReferences(docs_url=rd["docs_url"], title=rd["title"])
1128-
for rd in data["referenced_documents"]
1154+
for rd in data.get("referenced_documents", [])
11291155
]
11301156
self.event.chat_response = anonymize_struct(data["response"])
11311157
self.event.chat_response = (
@@ -1203,6 +1229,8 @@ def post(self, request) -> StreamingHttpResponse:
12031229
self.event.conversation_id = conversation_id
12041230
self.event.modelName = self.req_model_id or self.llm.config.model_id
12051231

1232+
mcp_headers = self.get_mcp_headers(request, self.llm.config)
1233+
12061234
return self.llm.invoke(
12071235
StreamingChatBotParameters.init(
12081236
query=req_query,
@@ -1211,5 +1239,6 @@ def post(self, request) -> StreamingHttpResponse:
12111239
provider=req_provider,
12121240
conversation_id=conversation_id,
12131241
media_type=media_type,
1242+
mcp_headers=mcp_headers,
12141243
)
12151244
)

ansible_ai_connect/main/settings/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def is_ssl_enabled(value: str) -> bool:
595595
# ------------------------------------------
596596
CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER")
597597
CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true"
598-
LIGHTSPEED_TOOL_GROUP_TOKEN = os.getenv("LIGHTSPEED_TOOL_GROUP_TOKEN")
598+
CHATBOT_DEFAULT_SYSTEM_PROMPT = os.getenv("CHATBOT_DEFAULT_SYSTEM_PROMPT")
599599
# ==========================================
600600

601601
# ==========================================

tools/docker-compose/compose.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ services:
7676
- CHATBOT_URL=${CHATBOT_URL}
7777
- CHATBOT_DEFAULT_PROVIDER=${CHATBOT_DEFAULT_PROVIDER}
7878
- CHATBOT_DEFAULT_MODEL=${CHATBOT_DEFAULT_MODEL}
79+
- CHATBOT_DEFAULT_SYSTEM_PROMPT=${CHATBOT_DEFAULT_SYSTEM_PROMPT}
7980
- ANSIBLE_AI_MODEL_MESH_CONFIG=${ANSIBLE_AI_MODEL_MESH_CONFIG}
8081
- ANSIBLE_AI_ENABLE_ROLE_GEN_ENDPOINT=${ANSIBLE_AI_ENABLE_ROLE_GEN_ENDPOINT}
8182
- AAP_API_URL=${AAP_API_URL}

0 commit comments

Comments
 (0)