Skip to content

Commit 236a42f

Browse files
authored
Merge pull request modelcontextprotocol#26 from jkoelker/jk/technicals
feat(tools): add technical tools
2 parents 0562e43 + c68e01b commit 236a42f

File tree

19 files changed

+2764
-30
lines changed

19 files changed

+2764
-30
lines changed

.github/workflows/ci.yml

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1+
---
12
name: CI
23

3-
on:
4+
"on":
45
push:
5-
branches: ["main"]
6+
branches:
7+
- main
68
pull_request:
79

810
jobs:
911
lint-and-test:
1012
runs-on: ubuntu-latest
1113
strategy:
1214
matrix:
13-
python-version: ["3.10", "3.11", "3.12"]
15+
python-version:
16+
- 3.10
17+
- 3.11
18+
- 3.12
1419
steps:
1520
- name: Checkout
1621
uses: actions/checkout@v4
@@ -21,7 +26,7 @@ jobs:
2126
python-version: ${{ matrix.python-version }}
2227

2328
- name: Sync dependencies
24-
run: uv sync --group dev
29+
run: uv sync --group dev --group ta
2530

2631
- name: Lint (ruff)
2732
run: uv run ruff check .
@@ -30,4 +35,4 @@ jobs:
3035
run: uv run pyright
3136

3237
- name: Run tests
33-
run: uv run python -m pytest
38+
run: uv run pytest

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ dev = [
3232
"pytest>=8.4.2",
3333
"ruff>=0.14.0",
3434
]
35+
ta = [
36+
"pandas>=2.2.2",
37+
"pandas-ta-classic>=0.3.36",
38+
]
3539

3640
[tool.pytest.ini_options]
3741
testpaths = ["tests"]

src/schwab_mcp/cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ def auth(
114114
is_flag=True,
115115
help="Allow tools to modify the portfolios, placing trades, etc.",
116116
)
117+
@click.option(
118+
"--no-technical-tools",
119+
default=False,
120+
is_flag=True,
121+
help="Disable optional technical analysis tools.",
122+
)
117123
@click.option(
118124
"--discord-token",
119125
type=str,
@@ -150,6 +156,7 @@ def server(
150156
discord_channel_id: int | None,
151157
discord_approver: tuple[str, ...],
152158
discord_timeout: int,
159+
no_technical_tools: bool,
153160
) -> int:
154161
"""Run the Schwab MCP server."""
155162
# No logging to stderr when in MCP mode (we'll use proper MCP responses)
@@ -258,6 +265,7 @@ def server(
258265
client,
259266
approval_manager=approval_manager,
260267
allow_write=allow_write,
268+
enable_technical_tools=not no_technical_tools,
261269
)
262270
anyio.run(server.run, backend="asyncio")
263271
return 0

src/schwab_mcp/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,18 @@ def __init__(
5555
approval_manager: ApprovalManager,
5656
*,
5757
allow_write: bool,
58+
enable_technical_tools: bool = True,
5859
) -> None:
5960
self._server = FastMCP(
6061
name=name,
6162
lifespan=_client_lifespan(client, approval_manager),
6263
)
63-
register_tools(self._server, client, allow_write=allow_write)
64+
register_tools(
65+
self._server,
66+
client,
67+
allow_write=allow_write,
68+
enable_technical=enable_technical_tools,
69+
)
6470

6571
async def run(self) -> None:
6672
"""Run the server using FastMCP's stdio transport."""

src/schwab_mcp/tools/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import logging
4+
35
from mcp.server.fastmcp import FastMCP
46
from schwab.client import AsyncClient
57

@@ -9,8 +11,11 @@
911
from schwab_mcp.tools import orders as _orders
1012
from schwab_mcp.tools import quotes as _quotes
1113
from schwab_mcp.tools import tools as _tools
14+
from schwab_mcp.tools import technical as _technical
1215
from schwab_mcp.tools import transactions as _txns
1316

17+
logger = logging.getLogger(__name__)
18+
1419
_TOOL_MODULES = (
1520
_tools,
1621
_account,
@@ -22,11 +27,21 @@
2227
)
2328

2429

25-
def register_tools(server: FastMCP, client: AsyncClient, *, allow_write: bool) -> None:
30+
def register_tools(
31+
server: FastMCP,
32+
client: AsyncClient,
33+
*,
34+
allow_write: bool,
35+
enable_technical: bool = True,
36+
) -> None:
2637
"""Register all Schwab tools with the provided FastMCP server."""
2738
_ = client
2839

29-
for module in _TOOL_MODULES:
40+
modules = _TOOL_MODULES
41+
if enable_technical:
42+
modules = modules + (_technical,)
43+
44+
for module in modules:
3045
register_module = getattr(module, "register", None)
3146
if register_module is None:
3247
raise AttributeError(f"Tool module {module.__name__} missing register()")

src/schwab_mcp/tools/_registration.py

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import functools
77
import inspect
88
import logging
9+
import sys
10+
import types
911
import uuid
10-
from typing import Any
12+
from typing import Annotated, Any, Union, cast, get_args, get_origin, get_type_hints
1113

1214
from mcp.server.fastmcp import FastMCP, Context as MCPContext
1315
from mcp.types import ToolAnnotations
@@ -23,13 +25,55 @@
2325
_APPROVAL_WAIT_MESSAGE = "Waiting for reviewer approval…"
2426

2527

26-
def _ensure_schwab_context(func: ToolFn) -> ToolFn:
28+
def _is_context_annotation(annotation: Any) -> bool:
29+
if annotation in (inspect._empty, None):
30+
return False
31+
if annotation is SchwabContext:
32+
return True
33+
if annotation == "SchwabContext":
34+
return True
35+
if isinstance(annotation, str):
36+
return annotation == "SchwabContext"
37+
38+
origin = get_origin(annotation)
39+
if origin is None:
40+
return False
41+
42+
if origin in (Annotated,):
43+
args = get_args(annotation)
44+
return bool(args) and _is_context_annotation(args[0])
45+
46+
if origin in (Union, types.UnionType):
47+
return any(_is_context_annotation(arg) for arg in get_args(annotation))
48+
49+
return False
50+
51+
52+
def _resolve_context_parameters(func: ToolFn) -> tuple[inspect.Signature, list[str]]:
2753
signature = inspect.signature(func)
28-
ctx_params = [
29-
name
30-
for name, param in signature.parameters.items()
31-
if param.annotation is SchwabContext
32-
]
54+
55+
module = sys.modules.get(func.__module__)
56+
globalns = vars(module) if module else {}
57+
58+
type_hints: dict[str, Any]
59+
try:
60+
type_hints = get_type_hints(func, globalns=globalns, include_extras=True)
61+
except TypeError:
62+
type_hints = get_type_hints(func, globalns=globalns)
63+
except Exception:
64+
type_hints = {}
65+
66+
ctx_params = []
67+
for name, param in signature.parameters.items():
68+
annotation = type_hints.get(name, param.annotation)
69+
if _is_context_annotation(annotation):
70+
ctx_params.append(name)
71+
72+
return signature, ctx_params
73+
74+
75+
def _ensure_schwab_context(func: ToolFn) -> ToolFn:
76+
signature, ctx_params = _resolve_context_parameters(func)
3377
if not ctx_params:
3478
return func
3579

@@ -57,6 +101,15 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
57101
return await result
58102
return result
59103

104+
# Ensure annotations referencing names from the original module remain resolvable.
105+
wrapper_globals = cast(dict[str, Any], getattr(wrapper, "__globals__", {}))
106+
module = inspect.getmodule(func)
107+
if module is not None:
108+
module_globals = vars(module)
109+
if wrapper_globals is not module_globals:
110+
for key, value in module_globals.items():
111+
wrapper_globals.setdefault(key, value)
112+
60113
return wrapper
61114

62115

@@ -68,12 +121,7 @@ def _format_argument(value: Any) -> str:
68121

69122

70123
def _wrap_with_approval(func: ToolFn) -> ToolFn:
71-
signature = inspect.signature(func)
72-
ctx_params = [
73-
name
74-
for name, param in signature.parameters.items()
75-
if param.annotation is SchwabContext
76-
]
124+
signature, ctx_params = _resolve_context_parameters(func)
77125
if not ctx_params:
78126
raise TypeError(
79127
f"Write tool '{func.__name__}' must accept a SchwabContext parameter for approval gating."
@@ -160,6 +208,14 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
160208
raise PermissionError(message)
161209
raise TimeoutError(message)
162210

211+
wrapper_globals = cast(dict[str, Any], getattr(wrapper, "__globals__", {}))
212+
module = inspect.getmodule(func)
213+
if module is not None:
214+
module_globals = vars(module)
215+
if wrapper_globals is not module_globals:
216+
for key, value in module_globals.items():
217+
wrapper_globals.setdefault(key, value)
218+
163219
return wrapper
164220

165221

src/schwab_mcp/tools/options.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#
2-
32
from typing import Annotated
43

54
import datetime
@@ -10,6 +9,9 @@
109
from schwab_mcp.tools.utils import JSONType, call
1110

1211

12+
_EXPIRATION_WINDOW_DAYS = 60
13+
14+
1315
def _parse_date(value: str | datetime.date | None) -> datetime.date | None:
1416
if value is None:
1517
return None
@@ -20,6 +22,29 @@ def _parse_date(value: str | datetime.date | None) -> datetime.date | None:
2022
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
2123

2224

25+
def _normalize_expiration_window(
26+
from_date: datetime.date | None,
27+
to_date: datetime.date | None,
28+
*,
29+
today: datetime.date | None = None,
30+
) -> tuple[datetime.date | None, datetime.date | None]:
31+
if from_date is None and to_date is None:
32+
today = datetime.date.today() if today is None else today
33+
return today, today + datetime.timedelta(days=_EXPIRATION_WINDOW_DAYS)
34+
35+
if from_date is None and to_date is not None:
36+
today = datetime.date.today() if today is None else today
37+
from_date = min(today, to_date)
38+
39+
if from_date is not None and to_date is None:
40+
to_date = from_date + datetime.timedelta(days=_EXPIRATION_WINDOW_DAYS)
41+
42+
if from_date is not None and to_date is not None and to_date < from_date:
43+
to_date = from_date
44+
45+
return from_date, to_date
46+
47+
2348
async def get_option_chain(
2449
ctx: SchwabContext,
2550
symbol: Annotated[str, "Symbol of the underlying security (e.g., 'AAPL', 'SPY')"],
@@ -45,12 +70,14 @@ async def get_option_chain(
4570
"""
4671
Returns option chain data (strikes, expirations, prices) for a symbol. Use for standard chains.
4772
Params: symbol, contract_type (CALL/PUT/ALL), strike_count (default 25), include_quotes (bool), from_date (YYYY-MM-DD), to_date (YYYY-MM-DD).
48-
Limit data returned using strike_count and date parameters.
73+
Limit data returned using strike_count and date parameters. When both dates are omitted the tool defaults to the next 60 calendar days to avoid oversized responses.
4974
"""
5075
client = ctx.options
5176

52-
from_date_obj = _parse_date(from_date)
53-
to_date_obj = _parse_date(to_date)
77+
from_date_obj, to_date_obj = _normalize_expiration_window(
78+
_parse_date(from_date),
79+
_parse_date(to_date),
80+
)
5481

5582
return await call(
5683
client.get_option_chain,
@@ -119,12 +146,16 @@ async def get_advanced_option_chain(
119146
"""
120147
Returns advanced option chain data with strategies, filters, and theoretical calculations. Use for complex analysis.
121148
Params: symbol, contract_type, strike_count, include_quotes, strategy (SINGLE/ANALYTICAL/etc.), interval, strike, strike_range (ITM/NTM/etc.), from/to_date, volatility/underlying_price/interest_rate/days_to_expiration (for ANALYTICAL), exp_month, option_type (STANDARD/NON_STANDARD/ALL).
122-
Limit data returned using strike_count and date parameters.
149+
Limit data returned using strike_count and date parameters. When both dates are omitted the tool defaults to the next 60 calendar days to avoid oversized responses.
123150
"""
124151
client = ctx.options
125152

126153
from_date_obj = _parse_date(from_date)
127154
to_date_obj = _parse_date(to_date)
155+
from_date_obj, to_date_obj = _normalize_expiration_window(
156+
from_date_obj,
157+
to_date_obj,
158+
)
128159

129160
return await call(
130161
client.get_option_chain,

0 commit comments

Comments
 (0)