1515)
1616
1717from ..types import (
18- PermissionResult ,
1918 PermissionResultAllow ,
2019 PermissionResultDeny ,
2120 SDKControlPermissionRequest ,
@@ -48,7 +47,8 @@ def __init__(
4847 transport : Transport ,
4948 is_streaming_mode : bool ,
5049 can_use_tool : Callable [
51- [str , dict [str , Any ], dict [str , Any ]], Awaitable [dict [str , Any ]]
50+ [str , dict [str , Any ], ToolPermissionContext ],
51+ Awaitable [PermissionResultAllow | PermissionResultDeny ],
5252 ]
5353 | None = None ,
5454 hooks : dict [str , list [dict [str , Any ]]] | None = None ,
@@ -191,7 +191,7 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:
191191 subtype = request_data ["subtype" ]
192192
193193 try :
194- response_data = {}
194+ response_data : dict [ str , Any ] = {}
195195
196196 if subtype == "can_use_tool" :
197197 permission_request : SDKControlPermissionRequest = request_data # type: ignore[assignment]
@@ -202,30 +202,28 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:
202202 context = ToolPermissionContext (
203203 signal = None , # TODO: Add abort signal support
204204 suggestions = permission_request .get ("permission_suggestions" , [])
205+ or [],
205206 )
206207
207208 response = await self .can_use_tool (
208209 permission_request ["tool_name" ],
209210 permission_request ["input" ],
210- context
211+ context ,
211212 )
212213
213214 # Convert PermissionResult to expected dict format
214215 if isinstance (response , PermissionResultAllow ):
215- response_data = {
216- "allow" : True
217- }
218- if response .updatedInput is not None :
219- response_data ["input" ] = response .updatedInput
216+ response_data = {"allow" : True }
217+ if response .updated_input is not None :
218+ response_data ["input" ] = response .updated_input
220219 # TODO: Handle updatedPermissions when control protocol supports it
221220 elif isinstance (response , PermissionResultDeny ):
222- response_data = {
223- "allow" : False ,
224- "reason" : response .message
225- }
221+ response_data = {"allow" : False , "reason" : response .message }
226222 # TODO: Handle interrupt flag when control protocol supports it
227223 else :
228- raise TypeError (f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got { type (response )} " )
224+ raise TypeError (
225+ f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got { type (response )} "
226+ )
229227
230228 elif subtype == "hook_callback" :
231229 hook_callback_request : SDKHookCallbackRequest = request_data # type: ignore[assignment]
@@ -241,15 +239,20 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:
241239 {"signal" : None }, # TODO: Add abort signal support
242240 )
243241
244- elif subtype == "mcp_request " :
242+ elif subtype == "mcp_message " :
245243 # Handle SDK MCP request
246244 server_name = request_data .get ("server_name" )
247245 mcp_message = request_data .get ("message" )
248246
249247 if not server_name or not mcp_message :
250248 raise Exception ("Missing server_name or message for MCP request" )
251249
252- response_data = await self ._handle_sdk_mcp_request (server_name , mcp_message )
250+ # Type narrowing - we've verified these are not None above
251+ assert isinstance (server_name , str )
252+ assert isinstance (mcp_message , dict )
253+ response_data = await self ._handle_sdk_mcp_request (
254+ server_name , mcp_message
255+ )
253256
254257 else :
255258 raise Exception (f"Unsupported control request subtype: { subtype } " )
@@ -317,7 +320,9 @@ async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]
317320 self .pending_control_results .pop (request_id , None )
318321 raise Exception (f"Control request timeout: { request .get ('subtype' )} " ) from e
319322
320- async def _handle_sdk_mcp_request (self , server_name : str , message : dict ) -> dict :
323+ async def _handle_sdk_mcp_request (
324+ self , server_name : str , message : dict [str , Any ]
325+ ) -> dict [str , Any ]:
321326 """Handle an MCP request for an SDK server.
322327
323328 This acts as a bridge between JSONRPC messages from the CLI
@@ -362,43 +367,50 @@ async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict
362367 {
363368 "name" : tool .name ,
364369 "description" : tool .description ,
365- "inputSchema" : tool .inputSchema .model_dump () if tool .inputSchema else {}
370+ "inputSchema" : tool .inputSchema .model_dump () # type: ignore[union-attr]
371+ if tool .inputSchema
372+ else {},
366373 }
367- for tool in result .root .tools
374+ for tool in result .root .tools # type: ignore[union-attr]
368375 ]
369376 return {
370377 "jsonrpc" : "2.0" ,
371378 "id" : message .get ("id" ),
372- "result" : {"tools" : tools_data }
379+ "result" : {"tools" : tools_data },
373380 }
374381
375382 elif method == "tools/call" :
376- request = CallToolRequest (
383+ call_request = CallToolRequest (
377384 method = method ,
378385 params = CallToolRequestParams (
379- name = params .get ("name" ),
380- arguments = params .get ("arguments" , {})
381- )
386+ name = params .get ("name" ), arguments = params .get ("arguments" , {})
387+ ),
382388 )
383389 handler = server .request_handlers .get (CallToolRequest )
384390 if handler :
385- result = await handler (request )
391+ result = await handler (call_request )
386392 # Convert MCP result to JSONRPC response
387393 content = []
388- for item in result .root .content :
389- if hasattr (item , ' text' ):
394+ for item in result .root .content : # type: ignore[union-attr]
395+ if hasattr (item , " text" ):
390396 content .append ({"type" : "text" , "text" : item .text })
391- elif hasattr (item , 'data' ) and hasattr (item , 'mimeType' ):
392- content .append ({"type" : "image" , "data" : item .data , "mimeType" : item .mimeType })
397+ elif hasattr (item , "data" ) and hasattr (item , "mimeType" ):
398+ content .append (
399+ {
400+ "type" : "image" ,
401+ "data" : item .data ,
402+ "mimeType" : item .mimeType ,
403+ }
404+ )
393405
394406 response_data = {"content" : content }
395- if hasattr (result .root , ' is_error' ) and result .root .is_error :
396- response_data ["is_error" ] = True
407+ if hasattr (result .root , " is_error" ) and result .root .is_error :
408+ response_data ["is_error" ] = True # type: ignore[assignment]
397409
398410 return {
399411 "jsonrpc" : "2.0" ,
400412 "id" : message .get ("id" ),
401- "result" : response_data
413+ "result" : response_data ,
402414 }
403415
404416 # Add more methods here as MCP SDK adds them (resources, prompts, etc.)
0 commit comments