
Commit d9011a3

Claude and I finished the job. No Unknown types anymore, and the tests actually call out to OpenAI.
1 parent 7106267 commit d9011a3

File tree

4 files changed: +97 -38 lines changed

test/test_mcp_server.py

Lines changed: 65 additions & 7 deletions

@@ -3,13 +3,69 @@

 """End-to-end tests for the MCP server."""

+from typing import Any
+
 import pytest

-from mcp.types import TextContent
+from mcp.client.session import ClientSession as ClientSessionType
+from mcp.shared.context import RequestContext
+from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent

 from fixtures import really_needs_auth


+async def sampling_callback(
+    context: RequestContext[ClientSessionType, Any, Any],
+    params: CreateMessageRequestParams,
+) -> CreateMessageResult:
+    """Sampling callback that uses OpenAI to generate responses."""
+    # Use OpenAI to generate a response
+    import openai
+    from openai.types.chat import ChatCompletionMessageParam
+
+    client = openai.AsyncOpenAI()
+
+    # Convert MCP SamplingMessage to OpenAI format
+    messages: list[ChatCompletionMessageParam] = []
+    for msg in params.messages:
+        # Handle TextContent
+        content: str
+        if isinstance(msg.content, TextContent):
+            content = msg.content.text
+        elif hasattr(msg.content, "text"):
+            content = msg.content.text  # type: ignore
+        else:
+            content = str(msg.content)
+
+        # MCP roles are "user" or "assistant", which are compatible with OpenAI
+        if msg.role == "user":
+            messages.append({"role": "user", "content": content})
+        else:
+            messages.append({"role": "assistant", "content": content})
+
+    # Add system prompt if provided
+    if params.systemPrompt:
+        messages.insert(0, {"role": "system", "content": params.systemPrompt})
+
+    # Call OpenAI
+    response = await client.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=messages,
+        max_tokens=params.maxTokens,
+        temperature=params.temperature if params.temperature is not None else 1.0,
+    )
+
+    # Convert response to MCP format
+    return CreateMessageResult(
+        role="assistant",
+        content=TextContent(
+            type="text", text=response.choices[0].message.content or ""
+        ),
+        model=response.model,
+        stopReason="endTurn",
+    )
+
+
 @pytest.mark.asyncio
 async def test_mcp_server_query_conversation(really_needs_auth):
     """Test the query_conversation tool end-to-end using MCP client."""
@@ -25,7 +81,9 @@ async def test_mcp_server_query_conversation(really_needs_auth):

     # Create client session and connect to server
     async with stdio_client(server_params) as (read, write):
-        async with ClientSession(read, write) as session:
+        async with ClientSession(
+            read, write, sampling_callback=sampling_callback
+        ) as session:
             # Initialize the session
             await session.initialize()

@@ -49,9 +107,7 @@ async def test_mcp_server_query_conversation(really_needs_auth):

             # Type narrow the content to TextContent
             content_item = result.content[0]
-            assert isinstance(
-                content_item, TextContent
-            ), f"Expected TextContent, got {type(content_item)}"
+            assert isinstance(content_item, TextContent)
             response_text = content_item.text

             # Parse response (it should be JSON with success, answer, time_used)
@@ -68,7 +124,7 @@ async def test_mcp_server_query_conversation(really_needs_auth):


 @pytest.mark.asyncio
-async def test_mcp_server_empty_question(really_needs_auth):
+async def test_mcp_server_empty_question():
     """Test the query_conversation tool with an empty question."""
     from mcp import ClientSession, StdioServerParameters
     from mcp.client.stdio import stdio_client
@@ -82,7 +138,9 @@ async def test_mcp_server_empty_question(really_needs_auth):

     # Create client session and connect to server
     async with stdio_client(server_params) as (read, write):
-        async with ClientSession(read, write) as session:
+        async with ClientSession(
+            read, write, sampling_callback=sampling_callback
+        ) as session:
             # Initialize the session
             await session.initialize()

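For reference, a minimal sketch (not part of this commit) of exercising the new sampling_callback directly, without spawning the server process. It assumes OPENAI_API_KEY is set and that the test/ directory is on sys.path so the module imports as test_mcp_server; since the callback never reads its context argument, a cast-away dummy stands in for it.

import asyncio
from typing import Any, cast

from mcp.types import CreateMessageRequestParams, SamplingMessage, TextContent

from test_mcp_server import sampling_callback


async def main() -> None:
    # One user message, mirroring what an MCP server would send for sampling.
    params = CreateMessageRequestParams(
        messages=[
            SamplingMessage(
                role="user",
                content=TextContent(type="text", text="Say hello in one word."),
            )
        ],
        maxTokens=32,
    )
    # The callback ignores its context argument, so a dummy suffices here.
    result = await sampling_callback(cast(Any, None), params)
    assert isinstance(result.content, TextContent)
    print(result.model, "->", result.content.text)


asyncio.run(main())

This checks the OpenAI bridge in isolation before paying for the full end-to-end round trip.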

typeagent/mcp/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

typeagent/mcp/__main__.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

typeagent/mcp/server.py

Lines changed: 32 additions & 18 deletions

@@ -7,24 +7,24 @@
 import time
 from typing import Any

-from mcp.server.fastmcp import FastMCP
-from mcp.shared.context import RequestContext
-from mcp.types import TextContent
+from mcp.server.fastmcp import Context, FastMCP
+from mcp.server.session import ServerSession
+from mcp.types import SamplingMessage, TextContent
 import typechat

 from typeagent.aitools import embeddings, utils
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
 from typeagent.knowpro import answers, query, searchlang
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.knowpro.answer_response_schema import AnswerResponse
 from typeagent.knowpro.search_query_schema import SearchQuery
 from typeagent.podcasts import podcast
+from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex


 class MCPTypeChatModel(typechat.TypeChatLanguageModel):
     """TypeChat language model that uses MCP sampling API instead of direct API calls."""

-    def __init__(self, session: Any):
+    def __init__(self, session: ServerSession):
         """Initialize with MCP session for sampling.

         Args:
@@ -37,19 +37,29 @@ async def complete(
     ) -> typechat.Result[str]:
         """Request completion from the MCP client's LLM."""
         try:
-            # Convert prompt to message format
+            # Convert prompt to MCP SamplingMessage format
+            sampling_messages: list[SamplingMessage]
             if isinstance(prompt, str):
-                messages = [{"role": "user", "content": prompt}]
+                sampling_messages = [
+                    SamplingMessage(
+                        role="user", content=TextContent(type="text", text=prompt)
+                    )
+                ]
             else:
-                # PromptSection list: convert to messages
-                messages = []
+                # PromptSection list: convert to SamplingMessage
+                sampling_messages = []
                 for section in prompt:
                     role = "user" if section["role"] == "user" else "assistant"
-                    messages.append({"role": role, "content": section["content"]})
+                    sampling_messages.append(
+                        SamplingMessage(
+                            role=role,
+                            content=TextContent(type="text", text=section["content"]),
+                        )
+                    )

             # Use MCP sampling to request completion from client
             result = await self.session.create_message(
-                messages=messages, max_tokens=16384
+                messages=sampling_messages, max_tokens=16384
             )

             # Extract text content from response
@@ -58,7 +68,7 @@ async def complete(
                 return typechat.Success(result.content.text)
             elif isinstance(result.content, list):
                 # Handle list of content items
-                text_parts = []
+                text_parts: list[str] = []
                 for item in result.content:
                     if isinstance(item, TextContent):
                         text_parts.append(item.text)
@@ -77,19 +87,21 @@
 class ProcessingContext:
     lang_search_options: searchlang.LanguageSearchOptions
     answer_context_options: answers.AnswerContextOptions
-    query_context: query.QueryEvalContext
+    query_context: query.QueryEvalContext[
+        podcast.PodcastMessage, TermToSemanticRefIndex
+    ]
     embedding_model: embeddings.AsyncEmbeddingModel
     query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
     answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]

     def __repr__(self) -> str:
-        parts = []
+        parts: list[str] = []
         parts.append(f"{self.lang_search_options}")
         parts.append(f"{self.answer_context_options}")
         return f"Context({', '.join(parts)})"


-async def make_context(session: Any) -> ProcessingContext:
+async def make_context(session: ServerSession) -> ProcessingContext:
     """Create processing context using MCP-based language model.

     Args:
@@ -135,7 +147,7 @@ async def make_context(session: Any) -> ProcessingContext:

 async def load_podcast_index(
     podcast_file_prefix: str, settings: ConversationSettings
-) -> query.QueryEvalContext:
+) -> query.QueryEvalContext[podcast.PodcastMessage, Any]:
     conversation = await podcast.Podcast.read_from_file(podcast_file_prefix, settings)
     assert (
         conversation is not None
@@ -155,7 +167,9 @@ class QuestionResponse:


 @mcp.tool()
-async def query_conversation(question: str, ctx: RequestContext) -> QuestionResponse:
+async def query_conversation(
+    question: str, ctx: Context[ServerSession, Any, Any]
+) -> QuestionResponse:
     """Send a question to the memory server and get an answer back"""
     t0 = time.time()
     question = question.strip()
@@ -164,7 +178,7 @@ async def query_conversation(question: str, ctx: RequestContext) -> QuestionResp
         return QuestionResponse(
             success=False, answer="No question provided", time_used=dt
         )
-    context = await make_context(ctx.session)
+    context = await make_context(ctx.request_context.session)

     # Stages 1, 2, 3 (LLM -> proto-query, compile, execute query)
     result = await searchlang.search_conversation_with_language(
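To isolate the typed-Context pattern this diff introduces, here is a minimal hypothetical FastMCP server (not from this commit; the names demo and echo_via_llm are made up) whose single tool round-trips a prompt through the connected client's sampler via ServerSession.create_message, the same call MCPTypeChatModel.complete makes above.

from typing import Any

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.types import SamplingMessage, TextContent

demo = FastMCP("sampling-demo")


@demo.tool()
async def echo_via_llm(prompt: str, ctx: Context[ServerSession, Any, Any]) -> str:
    """Ask the connected client's LLM to answer the prompt."""
    session = ctx.request_context.session  # a typed ServerSession, not Any
    result = await session.create_message(
        messages=[
            SamplingMessage(
                role="user", content=TextContent(type="text", text=prompt)
            )
        ],
        max_tokens=256,
    )
    # Sampling results carry one content item; only text is expected here.
    return result.content.text if isinstance(result.content, TextContent) else ""


if __name__ == "__main__":
    demo.run()

Paired with a client that registers a sampling_callback, as in the tests above, this closes the loop: the server's create_message request is answered by whichever model the client wires in.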
