77import time
88from typing import Any
99
10- from mcp .server .fastmcp import FastMCP
11- from mcp .shared . context import RequestContext
12- from mcp .types import TextContent
10+ from mcp .server .fastmcp import Context , FastMCP
11+ from mcp .server . session import ServerSession
12+ from mcp .types import SamplingMessage , TextContent
1313import typechat
1414
1515from typeagent .aitools import embeddings , utils
16- from typeagent .aitools .embeddings import AsyncEmbeddingModel
1716from typeagent .knowpro import answers , query , searchlang
1817from typeagent .knowpro .convsettings import ConversationSettings
1918from typeagent .knowpro .answer_response_schema import AnswerResponse
2019from typeagent .knowpro .search_query_schema import SearchQuery
2120from typeagent .podcasts import podcast
21+ from typeagent .storage .memory .semrefindex import TermToSemanticRefIndex
2222
2323
2424class MCPTypeChatModel (typechat .TypeChatLanguageModel ):
2525 """TypeChat language model that uses MCP sampling API instead of direct API calls."""
2626
27- def __init__ (self , session : Any ):
27+ def __init__ (self , session : ServerSession ):
2828 """Initialize with MCP session for sampling.
2929
3030 Args:
@@ -37,19 +37,29 @@ async def complete(
3737 ) -> typechat .Result [str ]:
3838 """Request completion from the MCP client's LLM."""
3939 try :
40- # Convert prompt to message format
40+ # Convert prompt to MCP SamplingMessage format
41+ sampling_messages : list [SamplingMessage ]
4142 if isinstance (prompt , str ):
42- messages = [{"role" : "user" , "content" : prompt }]
43+ sampling_messages = [
44+ SamplingMessage (
45+ role = "user" , content = TextContent (type = "text" , text = prompt )
46+ )
47+ ]
4348 else :
44- # PromptSection list: convert to messages
45- messages = []
49+ # PromptSection list: convert to SamplingMessage
50+ sampling_messages = []
4651 for section in prompt :
4752 role = "user" if section ["role" ] == "user" else "assistant"
48- messages .append ({"role" : role , "content" : section ["content" ]})
53+ sampling_messages .append (
54+ SamplingMessage (
55+ role = role ,
56+ content = TextContent (type = "text" , text = section ["content" ]),
57+ )
58+ )
4959
5060 # Use MCP sampling to request completion from client
5161 result = await self .session .create_message (
52- messages = messages , max_tokens = 16384
62+ messages = sampling_messages , max_tokens = 16384
5363 )
5464
5565 # Extract text content from response
@@ -58,7 +68,7 @@ async def complete(
5868 return typechat .Success (result .content .text )
5969 elif isinstance (result .content , list ):
6070 # Handle list of content items
61- text_parts = []
71+ text_parts : list [ str ] = []
6272 for item in result .content :
6373 if isinstance (item , TextContent ):
6474 text_parts .append (item .text )
@@ -77,19 +87,21 @@ async def complete(
class ProcessingContext:
    """Bundles everything needed to answer one question over the podcast index.

    Holds the search/answer option objects, the query-evaluation context for
    the loaded podcast conversation, the embedding model, and the two TypeChat
    translators (question -> SearchQuery, context -> AnswerResponse).

    NOTE(review): the attributes below are bare annotations with no __init__
    visible; presumably this class is decorated with @dataclass just above
    this view -- confirm before relying on a generated constructor.
    """

    # Options controlling the natural-language search stages.
    lang_search_options: searchlang.LanguageSearchOptions
    # Options controlling how answer context is assembled for the LLM.
    answer_context_options: answers.AnswerContextOptions
    # Evaluation context parameterized on the podcast message and index types.
    query_context: query.QueryEvalContext[
        podcast.PodcastMessage, TermToSemanticRefIndex
    ]
    # Embedding model used for semantic-similarity lookups.
    embedding_model: embeddings.AsyncEmbeddingModel
    # TypeChat translator that turns a user question into a SearchQuery.
    query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
    # TypeChat translator that turns retrieved context into an AnswerResponse.
    answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]

    def __repr__(self) -> str:
        # Summarize only the two lightweight option objects; the heavyweight
        # members (query context, model, translators) are deliberately omitted.
        parts: list[str] = []
        parts.append(f"{self.lang_search_options}")
        parts.append(f"{self.answer_context_options}")
        return f"Context({', '.join(parts)})"
90102
91103
92- async def make_context (session : Any ) -> ProcessingContext :
104+ async def make_context (session : ServerSession ) -> ProcessingContext :
93105 """Create processing context using MCP-based language model.
94106
95107 Args:
@@ -135,7 +147,7 @@ async def make_context(session: Any) -> ProcessingContext:
135147
136148async def load_podcast_index (
137149 podcast_file_prefix : str , settings : ConversationSettings
138- ) -> query .QueryEvalContext :
150+ ) -> query .QueryEvalContext [ podcast . PodcastMessage , Any ] :
139151 conversation = await podcast .Podcast .read_from_file (podcast_file_prefix , settings )
140152 assert (
141153 conversation is not None
@@ -155,7 +167,9 @@ class QuestionResponse:
155167
156168
157169@mcp .tool ()
158- async def query_conversation (question : str , ctx : RequestContext ) -> QuestionResponse :
170+ async def query_conversation (
171+ question : str , ctx : Context [ServerSession , Any , Any ]
172+ ) -> QuestionResponse :
159173 """Send a question to the memory server and get an answer back"""
160174 t0 = time .time ()
161175 question = question .strip ()
@@ -164,7 +178,7 @@ async def query_conversation(question: str, ctx: RequestContext) -> QuestionResp
164178 return QuestionResponse (
165179 success = False , answer = "No question provided" , time_used = dt
166180 )
167- context = await make_context (ctx .session )
181+ context = await make_context (ctx .request_context . session )
168182
169183 # Stages 1, 2, 3 (LLM -> proto-query, compile, execute query)
170184 result = await searchlang .search_conversation_with_language (
0 commit comments