66import functools
77import inspect
88import logging
9+ import sys
10+ import types
911import uuid
10- from typing import Any
12+ from typing import Annotated , Any , Union , cast , get_args , get_origin , get_type_hints
1113
1214from mcp .server .fastmcp import FastMCP , Context as MCPContext
1315from mcp .types import ToolAnnotations
2325_APPROVAL_WAIT_MESSAGE = "Waiting for reviewer approval…"
2426
2527
26- def _ensure_schwab_context (func : ToolFn ) -> ToolFn :
28+ def _is_context_annotation (annotation : Any ) -> bool :
29+ if annotation in (inspect ._empty , None ):
30+ return False
31+ if annotation is SchwabContext :
32+ return True
33+ if annotation == "SchwabContext" :
34+ return True
35+ if isinstance (annotation , str ):
36+ return annotation == "SchwabContext"
37+
38+ origin = get_origin (annotation )
39+ if origin is None :
40+ return False
41+
42+ if origin in (Annotated ,):
43+ args = get_args (annotation )
44+ return bool (args ) and _is_context_annotation (args [0 ])
45+
46+ if origin in (Union , types .UnionType ):
47+ return any (_is_context_annotation (arg ) for arg in get_args (annotation ))
48+
49+ return False
50+
51+
52+ def _resolve_context_parameters (func : ToolFn ) -> tuple [inspect .Signature , list [str ]]:
2753 signature = inspect .signature (func )
28- ctx_params = [
29- name
30- for name , param in signature .parameters .items ()
31- if param .annotation is SchwabContext
32- ]
54+
55+ module = sys .modules .get (func .__module__ )
56+ globalns = vars (module ) if module else {}
57+
58+ type_hints : dict [str , Any ]
59+ try :
60+ type_hints = get_type_hints (func , globalns = globalns , include_extras = True )
61+ except TypeError :
62+ type_hints = get_type_hints (func , globalns = globalns )
63+ except Exception :
64+ type_hints = {}
65+
66+ ctx_params = []
67+ for name , param in signature .parameters .items ():
68+ annotation = type_hints .get (name , param .annotation )
69+ if _is_context_annotation (annotation ):
70+ ctx_params .append (name )
71+
72+ return signature , ctx_params
73+
74+
75+ def _ensure_schwab_context (func : ToolFn ) -> ToolFn :
76+ signature , ctx_params = _resolve_context_parameters (func )
3377 if not ctx_params :
3478 return func
3579
@@ -57,6 +101,15 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
57101 return await result
58102 return result
59103
104+ # Ensure annotations referencing names from the original module remain resolvable.
105+ wrapper_globals = cast (dict [str , Any ], getattr (wrapper , "__globals__" , {}))
106+ module = inspect .getmodule (func )
107+ if module is not None :
108+ module_globals = vars (module )
109+ if wrapper_globals is not module_globals :
110+ for key , value in module_globals .items ():
111+ wrapper_globals .setdefault (key , value )
112+
60113 return wrapper
61114
62115
@@ -68,12 +121,7 @@ def _format_argument(value: Any) -> str:
68121
69122
70123def _wrap_with_approval (func : ToolFn ) -> ToolFn :
71- signature = inspect .signature (func )
72- ctx_params = [
73- name
74- for name , param in signature .parameters .items ()
75- if param .annotation is SchwabContext
76- ]
124+ signature , ctx_params = _resolve_context_parameters (func )
77125 if not ctx_params :
78126 raise TypeError (
79127 f"Write tool '{ func .__name__ } ' must accept a SchwabContext parameter for approval gating."
@@ -160,6 +208,14 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
160208 raise PermissionError (message )
161209 raise TimeoutError (message )
162210
211+ wrapper_globals = cast (dict [str , Any ], getattr (wrapper , "__globals__" , {}))
212+ module = inspect .getmodule (func )
213+ if module is not None :
214+ module_globals = vars (module )
215+ if wrapper_globals is not module_globals :
216+ for key , value in module_globals .items ():
217+ wrapper_globals .setdefault (key , value )
218+
163219 return wrapper
164220
165221
0 commit comments