Skip to content

Commit c736daa

Browse files
committed
Fixed code to correctly execute the e2e test for MCP
1 parent 4b358ab commit c736daa

File tree

14 files changed

+784
-310
lines changed

14 files changed

+784
-310
lines changed

dev-tools/mcp-mock-server/server.py

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ def _capture_headers(self) -> None:
6060
if len(request_log) > 10:
6161
request_log.pop(0)
6262

63-
def do_POST(self) -> None: # pylint: disable=invalid-name
63+
def do_POST(
64+
self,
65+
) -> (
66+
None
67+
): # pylint: disable=invalid-name,too-many-locals,too-many-branches,too-many-statements
6468
"""Handle POST requests (MCP protocol endpoints)."""
6569
self._capture_headers()
6670

@@ -77,23 +81,40 @@ def do_POST(self) -> None: # pylint: disable=invalid-name
7781
request_id = 1
7882
method = "unknown"
7983

84+
# Log the RPC method in the request log
85+
if request_log:
86+
request_log[-1]["rpc_method"] = method
87+
8088
# Determine tool name based on authorization header to avoid collisions
8189
auth_header = self.headers.get("Authorization", "")
8290

8391
# Initialize tool info defaults
8492
tool_name = "mock_tool_no_auth"
8593
tool_desc = "Mock tool with no authorization"
94+
error_mode = False
8695

8796
# Match based on token content
88-
if "test-secret-token" in auth_header:
89-
tool_name = "mock_tool_file"
90-
tool_desc = "Mock tool with file-based auth"
91-
elif "my-k8s-token" in auth_header:
92-
tool_name = "mock_tool_k8s"
93-
tool_desc = "Mock tool with Kubernetes token"
94-
elif "my-client-token" in auth_header:
95-
tool_name = "mock_tool_client"
96-
tool_desc = "Mock tool with client-provided token"
97+
match True:
98+
case _ if "test-secret-token" in auth_header:
99+
tool_name = "mock_tool_file"
100+
tool_desc = "Mock tool with file-based auth"
101+
case _ if "my-k8s-token" in auth_header:
102+
tool_name = "mock_tool_k8s"
103+
tool_desc = "Mock tool with Kubernetes token"
104+
case _ if "my-client-token" in auth_header:
105+
tool_name = "mock_tool_client"
106+
tool_desc = "Mock tool with client-provided token"
107+
case _ if "error-mode" in auth_header:
108+
tool_name = "mock_tool_error"
109+
tool_desc = "Mock tool configured to return errors"
110+
error_mode = True
111+
case _:
112+
# Default case already set above
113+
pass
114+
115+
# Log the tool name in the request log
116+
if request_log:
117+
request_log[-1]["tool_name"] = tool_name
97118

98119
# Handle MCP protocol methods using match statement
99120
response: dict = {}
@@ -145,29 +166,46 @@ def do_POST(self) -> None: # pylint: disable=invalid-name
145166
tool_called = params.get("name", "unknown")
146167
arguments = params.get("arguments", {})
147168

148-
# Build result text
149-
auth_preview = (
150-
auth_header[:50] if len(auth_header) > 50 else auth_header
151-
)
152-
result_text = (
153-
f"Mock tool '{tool_called}' executed successfully "
154-
f"with arguments: {arguments}. Auth used: {auth_preview}..."
155-
)
156-
157-
# Return successful tool execution result
158-
response = {
159-
"jsonrpc": "2.0",
160-
"id": request_id,
161-
"result": {
162-
"content": [
163-
{
164-
"type": "text",
165-
"text": result_text,
166-
}
167-
],
168-
"isError": False,
169-
},
170-
}
169+
# Check if error mode is enabled
170+
if error_mode:
171+
# Return error response
172+
response = {
173+
"jsonrpc": "2.0",
174+
"id": request_id,
175+
"result": {
176+
"content": [
177+
{
178+
"type": "text",
179+
"text": (
180+
f"Error: Tool '{tool_called}' "
181+
"execution failed - simulated error."
182+
),
183+
}
184+
],
185+
"isError": True,
186+
},
187+
}
188+
else:
189+
# Build result text
190+
result_text = (
191+
f"Mock tool '{tool_called}' executed successfully "
192+
f"with arguments: {arguments}."
193+
)
194+
195+
# Return successful tool execution result
196+
response = {
197+
"jsonrpc": "2.0",
198+
"id": request_id,
199+
"result": {
200+
"content": [
201+
{
202+
"type": "text",
203+
"text": result_text,
204+
}
205+
],
206+
"isError": False,
207+
},
208+
}
171209

172210
case _:
173211
# Generic success response for other methods
@@ -194,6 +232,10 @@ def do_GET(self) -> None: # pylint: disable=invalid-name
194232
)
195233
case "/debug/requests":
196234
self._send_json_response(request_log)
235+
case "/debug/clear":
236+
# Clear the request log
237+
request_log.clear()
238+
self._send_json_response({"status": "cleared", "request_count": 0})
197239
case "/":
198240
self._send_help_page()
199241
case _:
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Lightspeed Core Service - MCP Mock Server Test (Noop Auth)
2+
service:
3+
host: localhost
4+
port: 8080
5+
auth_enabled: false
6+
workers: 1
7+
color_log: true
8+
access_log: true
9+
llama_stack:
10+
use_as_library_client: true
11+
library_client_config_path: "dev-tools/test-configs/llama-stack-mcp-test.yaml"
12+
user_data_collection:
13+
feedback_enabled: false
14+
transcripts_enabled: false
15+
authentication:
16+
module: "noop"
17+
inference:
18+
default_model: "gpt-4o-mini"
19+
default_provider: "openai"
20+
mcp_servers:
21+
# Test 1: Static file-based authentication (HTTP)
22+
- name: "mock-file-auth"
23+
provider_id: "model-context-protocol"
24+
url: "http://localhost:9000"
25+
authorization_headers:
26+
Authorization: "/tmp/lightspeed-mcp-test-token"
27+
# Test 2: Kubernetes token forwarding (HTTP)
28+
- name: "mock-k8s-auth"
29+
provider_id: "model-context-protocol"
30+
url: "http://localhost:9000"
31+
authorization_headers:
32+
Authorization: "kubernetes"
33+
# Test 3: Client-provided token (HTTP - simplified for testing)
34+
- name: "mock-client-auth"
35+
provider_id: "model-context-protocol"
36+
url: "http://localhost:9000"
37+
authorization_headers:
38+
Authorization: "client"

docker-compose.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ services:
8484
- TENANT_ID=${TENANT_ID:-}
8585
- CLIENT_ID=${CLIENT_ID:-}
8686
- CLIENT_SECRET=${CLIENT_SECRET:-}
87+
entrypoint: >
88+
/bin/bash -c "
89+
echo 'test-secret-token-123' > /tmp/lightspeed-mcp-test-token &&
90+
/opt/app-root/src/scripts/run.sh
91+
"
8792
depends_on:
8893
llama-stack:
8994
condition: service_healthy

src/app/endpoints/query.py

Lines changed: 122 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Handler for REST API call to provide answer to query."""
22

3+
import asyncio
34
import ast
45
import logging
56
import re
@@ -77,6 +78,33 @@
7778
503: ServiceUnavailableResponse.openapi_response(),
7879
}
7980

81+
# Track background tasks to prevent garbage collection
82+
# Background tasks created with asyncio.create_task() need strong references
83+
# to prevent premature garbage collection before they complete
84+
background_tasks_set: set[asyncio.Task] = set()
85+
86+
87+
def create_background_task(coro: Any) -> None:
88+
"""Create a background task and track it to prevent garbage collection.
89+
90+
This function creates a detached async task that runs independently of the
91+
HTTP request lifecycle. Tasks are stored in a module-level set to maintain
92+
strong references, preventing garbage collection. When a task completes,
93+
it automatically removes itself from the set.
94+
95+
Args:
96+
coro: Coroutine to run as a background task
97+
"""
98+
try:
99+
task = asyncio.create_task(coro)
100+
background_tasks_set.add(task)
101+
task.add_done_callback(background_tasks_set.discard)
102+
logger.debug(
103+
f"Background task created, active tasks: {len(background_tasks_set)}"
104+
)
105+
except Exception as e:
106+
logger.error(f"Failed to create background task: {e}", exc_info=True)
107+
80108

81109
def is_transcripts_enabled() -> bool:
82110
"""Check if transcripts is enabled.
@@ -297,26 +325,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
297325
)
298326
)
299327

300-
# Get the initial topic summary for the conversation
301-
topic_summary = None
302-
with get_session() as session:
303-
existing_conversation = (
304-
session.query(UserConversation).filter_by(id=conversation_id).first()
305-
)
306-
if not existing_conversation:
307-
# Check if topic summary should be generated (default: True)
308-
should_generate = query_request.generate_topic_summary
309-
310-
if should_generate:
311-
logger.debug("Generating topic summary for new conversation")
312-
topic_summary = await get_topic_summary_func(
313-
query_request.query, client, llama_stack_model_id
314-
)
315-
else:
316-
logger.debug(
317-
"Topic summary generation disabled by request parameter"
318-
)
319-
topic_summary = None
320328
# Convert RAG chunks to dictionary format once for reuse
321329
logger.info("Processing RAG chunks...")
322330
rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
@@ -338,15 +346,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
338346
attachments=query_request.attachments or [],
339347
)
340348

341-
logger.info("Persisting conversation details...")
342-
persist_user_conversation_details(
343-
user_id=user_id,
344-
conversation_id=conversation_id,
345-
model=model_id,
346-
provider_id=provider_id,
347-
topic_summary=topic_summary,
348-
)
349-
350349
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
351350
cache_entry = CacheEntry(
352351
query=query_request.query,
@@ -376,15 +375,20 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
376375
conversation_id,
377376
cache_entry,
378377
_skip_userid_check,
379-
topic_summary,
378+
None, # topic_summary is generated in background task
380379
)
381380

382381
# Convert tool calls to response format
383382
logger.info("Processing tool calls...")
384383

385384
logger.info("Using referenced documents from response...")
386385

387-
available_quotas = get_available_quotas(configuration.quota_limiters, user_id)
386+
# Get available quotas if quota limiters are configured
387+
available_quotas = {}
388+
if configuration.quota_limiters:
389+
available_quotas = get_available_quotas(
390+
configuration.quota_limiters, user_id
391+
)
388392

389393
logger.info("Building final response...")
390394
response = QueryResponse(
@@ -399,10 +403,95 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
399403
output_tokens=token_usage.output_tokens,
400404
available_quotas=available_quotas,
401405
)
406+
407+
# Schedule conversation persistence as a detached background task
408+
# IMPORTANT: We use asyncio.create_task() instead of FastAPI's BackgroundTasks
409+
# for two critical reasons:
410+
# 1. Complete detachment from request context: The task runs independently,
411+
# not tied to the HTTP request lifecycle or middleware processing
412+
# 2. MCP session lifecycle compatibility: Llama Stack's MCPSessionManager.close_all()
413+
# aggressively cancels tasks within the request context. By creating a detached
414+
# task, we avoid this cancellation scope entirely.
415+
async def persist_with_topic_summary() -> None:
416+
"""Persist conversation with topic summary generation.
417+
418+
This function runs as a background task AFTER the HTTP response has been sent.
419+
420+
Strategy for MCP compatibility and database isolation:
421+
1. Wait 500ms for MCP session cleanup to complete naturally
422+
2. Then safely call LLM for topic summary generation without cancellation
423+
3. Use independent database sessions in thread pool to avoid connection issues
424+
4. Persist conversation details with or without topic summary
425+
426+
The delay ensures MCPSessionManager.close_all() has finished its cleanup
427+
before we make any new LLM calls, preventing CancelledError exceptions.
428+
Database operations run in thread pool to isolate from request lifecycle.
429+
"""
430+
logger.debug("Background task: waiting for MCP cleanup")
431+
# Give MCP sessions time to clean up (they close after response is sent)
432+
await asyncio.sleep(0.5) # 500ms should be enough for cleanup
433+
logger.debug("Background task: MCP cleanup complete")
434+
435+
topic_summary = None
436+
should_generate = (
437+
query_request.generate_topic_summary
438+
if query_request.generate_topic_summary is not None
439+
else True
440+
)
441+
442+
# Check if this is a new conversation and generate topic summary if needed
443+
if should_generate:
444+
try:
445+
446+
def check_conversation_exists() -> bool:
447+
"""Check if conversation exists in database (runs in thread pool)."""
448+
with get_session() as session:
449+
existing = (
450+
session.query(UserConversation)
451+
.filter_by(id=conversation_id)
452+
.first()
453+
)
454+
return existing is not None
455+
456+
# Run database check in thread pool to avoid connection issues
457+
conversation_exists = await asyncio.to_thread(
458+
check_conversation_exists
459+
)
460+
461+
if not conversation_exists:
462+
logger.debug("Generating topic summary for new conversation")
463+
topic_summary = await get_topic_summary_func(
464+
query_request.query, client, llama_stack_model_id
465+
)
466+
logger.info("Topic summary generated successfully")
467+
except Exception as e: # pylint: disable=broad-exception-caught
468+
logger.error("Failed to generate topic summary: %s", e)
469+
topic_summary = None
470+
471+
# Persist conversation
472+
try:
473+
474+
def persist_conversation() -> None:
475+
"""Persist conversation to database (runs in thread pool)."""
476+
persist_user_conversation_details(
477+
user_id=user_id,
478+
conversation_id=conversation_id,
479+
model=model_id,
480+
provider_id=provider_id,
481+
topic_summary=topic_summary,
482+
)
483+
484+
# Run persistence in thread pool to avoid connection issues
485+
await asyncio.to_thread(persist_conversation)
486+
logger.debug("Conversation persisted successfully")
487+
except Exception as e: # pylint: disable=broad-exception-caught
488+
logger.error("Failed to persist conversation: %s", e)
489+
490+
# Create detached task with strong reference to prevent garbage collection
491+
create_background_task(persist_with_topic_summary())
492+
402493
logger.info("Query processing completed successfully!")
403494
return response
404-
405-
# connection to Llama Stack server
406495
except APIConnectionError as e:
407496
# Update metrics for the LLM call failure
408497
metrics.llm_calls_failures_total.inc()

0 commit comments

Comments
 (0)