Skip to content

Commit 7664954

Browse files
add tests, clean out old mcps
1 parent dfab523 commit 7664954

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+373
-1172
lines changed

dev/generate_mcp_tools.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Define utilities for (mostly) auto-generating MCP tools.
2+
3+
This file will autogenerate a (Fast)MCP set of tools with
4+
type annotations.
5+
6+
The resultant tools are perhaps too general for use in an MCP.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import TYPE_CHECKING
12+
13+
from mp_api.client import MPRester
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Callable
17+
from pathlib import Path
18+
19+
20+
def get_annotation_signature(
21+
obj: Callable, tablen: int = 4
22+
) -> tuple[str | None, str | None]:
23+
"""Reconstruct the type annotations associated with a Callable.
24+
25+
Returns the type annotations on input, and the output
26+
kwargs as str if type annotations can be inferred.
27+
"""
28+
kwargs = None
29+
out_kwargs = None
30+
if (annos := obj.__annotations__) and (defaults := obj.__defaults__):
31+
non_ret_type = [k for k in annos if k != "return"]
32+
defaults = [f" = {val}" for val in defaults]
33+
if len(defaults) < len(non_ret_type):
34+
defaults = [""] * (len(non_ret_type) - len(defaults)) + defaults
35+
kwargs = ",\n".join(
36+
f"{' '*tablen}{k} : {v}{defaults[i]}"
37+
for i, (k, v) in enumerate(annos.items())
38+
if k != "return"
39+
)
40+
out_kwargs = ",\n".join(
41+
f"{' '*2*tablen}{k} = {k}" for k in annos if k != "return"
42+
)
43+
return kwargs, out_kwargs
44+
45+
46+
def regenerate_tools(
47+
client: MPRester | None = None, file_name: str | Path | None = None
48+
) -> str:
49+
"""Utility to regenerate the informative tool names with annotations."""
50+
func_str = """# ruff: noqa
51+
from __future__ import annotations
52+
53+
from datetime import datetime
54+
from typing import Literal
55+
56+
from emmet.core.chemenv import (
57+
COORDINATION_GEOMETRIES,
58+
COORDINATION_GEOMETRIES_IUCR,
59+
COORDINATION_GEOMETRIES_IUPAC,
60+
COORDINATION_GEOMETRIES_NAMES,
61+
)
62+
from emmet.core.electronic_structure import BSPathType, DOSProjectionType
63+
from emmet.core.grain_boundary import GBTypeEnum
64+
from emmet.core.mpid import MPID
65+
from emmet.core.thermo import ThermoType
66+
from emmet.core.summary import HasProps
67+
from emmet.core.symmetry import CrystalSystem
68+
from emmet.core.synthesis import SynthesisTypeEnum, OperationTypeEnum
69+
from emmet.core.vasp.calc_types import CalcType
70+
from emmet.core.xas import Edge, Type
71+
72+
from pymatgen.analysis.magnetism.analyzer import Ordering
73+
from pymatgen.core.periodic_table import Element
74+
from pymatgen.core.structure import Structure
75+
from pymatgen.electronic_structure.core import OrbitalType, Spin
76+
77+
"""
78+
79+
translate = {
80+
"chemenv": "chemical_environment",
81+
"dos": "density_of_states",
82+
"eos": "equation_of_state",
83+
"summary": "material",
84+
"robocrys": "crystal_summary",
85+
}
86+
87+
mp_client = client or MPRester()
88+
89+
def _get_rester_sub_name(name, route) -> str | None:
90+
for y in [x for x in dir(route) if not x.startswith("_")]:
91+
attr = getattr(route, y)
92+
if (
93+
(hasattr(attr, "__name__") and attr.__name__ == name)
94+
or (hasattr(attr, "__class__"))
95+
and attr.__class__.__name__ == name
96+
):
97+
return y
98+
return None
99+
100+
for x in mp_client._all_resters:
101+
if not (
102+
sub_rest_route := _get_rester_sub_name(x.__name__, mp_client.materials)
103+
):
104+
continue
105+
106+
search_method = "search"
107+
if "robocrys" in x.__name__.lower():
108+
search_method = "search_docs"
109+
110+
informed_name = sub_rest_route
111+
for k, v in translate.items():
112+
if k in informed_name:
113+
informed_name = informed_name.replace(k, v)
114+
115+
kwargs, out_kwargs = get_annotation_signature(getattr(x, search_method))
116+
if not kwargs:
117+
# FastMCP raises a ValueError if types are not provided:
118+
# `Functions with **kwargs are not supported as tools`
119+
continue
120+
func_str += (
121+
f"def get_{informed_name}_data(\n"
122+
f" self,\n{kwargs}\n) -> list[dict]:\n"
123+
f" return self.client.materials.{sub_rest_route}"
124+
f".search(\n{out_kwargs}\n)\n\n"
125+
)
126+
127+
helpers = [
128+
method
129+
for method in dir(mp_client)
130+
if any(
131+
method.startswith(signature)
132+
for signature in (
133+
"get",
134+
"find",
135+
)
136+
)
137+
]
138+
for func_name in helpers:
139+
func = getattr(mp_client, func_name)
140+
# MCP doesn't work with LRU cached functions?
141+
if hasattr(func, "cache_info"):
142+
continue
143+
144+
kwargs, out_kwargs = get_annotation_signature(func)
145+
if not kwargs:
146+
continue
147+
148+
informed_name = func_name.replace("find", "get")
149+
for k, v in translate.items():
150+
if k in informed_name:
151+
informed_name = informed_name.replace(k, v)
152+
153+
func_str += (
154+
f"def {informed_name}(\n"
155+
f" self,\n{kwargs}\n) -> list[dict]:\n"
156+
f" return self.client.{func_name}(\n"
157+
f"{out_kwargs}\n)\n\n"
158+
)
159+
160+
if file_name:
161+
with open(file_name, "w") as f:
162+
f.write(func_str)
163+
164+
return func_str

mp_api/mcp/mp_mcp.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
"""Define custom MCP tools for the Materials Project API."""
22
from __future__ import annotations
33

4-
from urllib.parse import urljoin
5-
6-
import httpx
74
from fastmcp import FastMCP
85

9-
from mp_api.mcp.tools import MPCoreMCP, MPMcpTools
10-
from mp_api.mcp.utils import _NeedsMPClient
6+
from mp_api.mcp.tools import MPCoreMCP
117

128
MCP_SERVER_INSTRUCTIONS = """
139
This MCP server defines search and document retrieval capabilities
@@ -28,29 +24,3 @@ def get_core_mcp() -> FastMCP:
2824
for k in {"search", "fetch"}:
2925
mp_mcp.tool(getattr(core_tools, k), name=k)
3026
return mp_mcp
31-
32-
33-
def get_mcp() -> FastMCP:
34-
"""MCP with finer depth of control over tool names."""
35-
mp_mcp = FastMCP("Materials_Project_MCP")
36-
mcp_tools = MPMcpTools()
37-
for attr in {x for x in dir(mcp_tools) if x.startswith("get_")}:
38-
mp_mcp.tool(getattr(mcp_tools, attr))
39-
40-
# Register tool to set the user's API key
41-
mp_mcp.tool(mcp_tools.update_user_api_key)
42-
return mp_mcp
43-
44-
45-
def get_bootstrap_mcp() -> FastMCP:
46-
"""Bootstrap an MP API MCP only from the OpenAPI spec."""
47-
client = _NeedsMPClient().client
48-
49-
return FastMCP.from_openapi(
50-
openapi_spec=httpx.get(urljoin(client.endpoint, "openapi.json")).json(),
51-
client=httpx.AsyncClient(
52-
base_url=client.endpoint,
53-
headers={"x-api-key": client.api_key},
54-
),
55-
name="MP_OpenAPI_MCP",
56-
)

mp_api/mcp/server.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,75 @@
1-
"""Run MCP."""
1+
"""Configure the Materials Project MCP server."""
22
from __future__ import annotations
33

4-
from mp_api.mcp.mp_mcp import get_core_mcp
4+
from argparse import ArgumentParser
5+
from typing import TYPE_CHECKING, get_args
6+
7+
from fastmcp import FastMCP
8+
from fastmcp.server.server import Transport
9+
10+
from mp_api.client.core.exceptions import MPRestError
11+
from mp_api.mcp.tools import MPCoreMCP
12+
13+
if TYPE_CHECKING:
14+
from collections.abc import Sequence
15+
from typing import Any
16+
17+
MCP_SERVER_INSTRUCTIONS = """
18+
This MCP server defines search and document retrieval capabilities
19+
for data in the Materials Project.
20+
Use the search tool to find relevant documents based on materials
21+
keywords.
22+
Then use the fetch tool to retrieve complete materials summary information.
23+
"""
24+
25+
26+
def get_core_mcp() -> FastMCP:
27+
"""Create an MCP compatible with OpenAI models."""
28+
mp_mcp = FastMCP(
29+
"Materials_Project_MCP",
30+
instructions=MCP_SERVER_INSTRUCTIONS,
31+
)
32+
core_tools = MPCoreMCP()
33+
for k in {"search", "fetch"}:
34+
mp_mcp.tool(getattr(core_tools, k), name=k)
35+
return mp_mcp
36+
37+
38+
def parse_server_args(args: Sequence[str] | None = None) -> dict[str, Any]:
39+
"""Parse CLI arguments for server configuration."""
40+
server_config = {"transport", "host", "port"}
41+
transport_vals = get_args(Transport)
42+
43+
arg_parser = ArgumentParser()
44+
arg_parser.add_argument(
45+
"--transport",
46+
type=str,
47+
required=False,
48+
)
49+
arg_parser.add_argument(
50+
"--host",
51+
type=str,
52+
required=False,
53+
)
54+
arg_parser.add_argument(
55+
"--port",
56+
type=int,
57+
required=False,
58+
)
59+
60+
parsed_args = arg_parser.parse_args(args=args)
61+
kwargs = {}
62+
for k in server_config:
63+
if (v := getattr(parsed_args, k, None)) is not None:
64+
if k == "transport" and v not in transport_vals:
65+
raise MPRestError(
66+
f"Invalid `transport={v}`, choose one of: {', '.join(transport_vals)}"
67+
)
68+
kwargs[k] = v
69+
return kwargs
70+
571

672
mcp = get_core_mcp()
773

874
if __name__ == "__main__":
9-
mcp.run()
75+
mcp.run(**parse_server_args())

0 commit comments

Comments
 (0)