diff --git a/src/infrahub_mcp/constants.py b/src/infrahub_mcp/constants.py index 01dde86..99c50da 100644 --- a/src/infrahub_mcp/constants.py +++ b/src/infrahub_mcp/constants.py @@ -1,3 +1,5 @@ +from __future__ import annotations + NAMESPACES_INTERNAL = ["Internal", "Profile", "Template"] schema_attribute_type_mapping = { diff --git a/src/infrahub_mcp/prompts/__init__.py b/src/infrahub_mcp/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/infrahub_mcp/prompts/main.md b/src/infrahub_mcp/prompts/main.md deleted file mode 100644 index 7218907..0000000 --- a/src/infrahub_mcp/prompts/main.md +++ /dev/null @@ -1,8 +0,0 @@ -You are an infrastructure specilist specialized in answering questions about the infrastructure. - -All the information you need are present in Infrahub and you can access it via an MCP server which exposes a number of tools. - -When someone ask about a specific data, you need to: -- Identify what is the associated kind in the schema for this data using the tool `schema_get_mapping` -- Retrieve more information about this specific model, including the option available to filter (tool : `get_node_filters`) -- Use the tool `get_objects` to query one or multiple objects \ No newline at end of file diff --git a/src/infrahub_mcp/prompts/usage_guide.py b/src/infrahub_mcp/prompts/usage_guide.py new file mode 100644 index 0000000..be0dbf4 --- /dev/null +++ b/src/infrahub_mcp/prompts/usage_guide.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Annotated + +from fastmcp import FastMCP +from mcp.types import PromptMessage, TextContent +from pydantic import Field + +mcp = FastMCP(name="Infrahub Usage Prompts", version="1.0.0") + + +@mcp.prompt( + name="answer_infra_question", + description="Answer infra questions by reading branch-scoped schema resources first, then using mapping+filters+query tools.", + tags={"schema", "query", "infra"}, +) +def answer_infra_question( + question: Annotated[str, Field(description="User question about the infra")], + kind_hint: Annotated[str | None, Field(default=None, description="Optional guess/hint for the schema kind")], + fields: Annotated[list[str] | None, Field(default=None, description="Optional list of fields to return")], + branch: Annotated[ + str | None, + Field(default=None, description="Branch to retrieve the objects from. Defaults to None (uses default branch)."), + ], +) -> PromptMessage: + # Resolve URIs (supports your 'current' static resource if you added it; otherwise set a default branch) + resolved_branch = branch or "current" + base_uri = f"infrahub://branch/{resolved_branch}/schema" + kind_uri = f"infrahub://branch/{resolved_branch}/schema/{{target_kind}}" + fields_display = fields if fields is not None else [] + + txt = f""" +You are an infrastructure specialist. + +User question: {question} + +PIPELINE (follow exactly): + +STEP 0 — Read schema (resources) +- Read: {base_uri} +- After choosing a kind, also read: {kind_uri} +- If any resource returns status="error", use its remediation and stop. + +STEP 1 — Identify target kind +- Try tool `schema_get_mapping(question)` to map the question to a kind. +- If that fails, infer from the schema resource (names/attributes/relationships). +- Use kind_hint if helpful: "{kind_hint or ""}". Decide a single target_kind and note any assumptions briefly. + +STEP 2 — Validate fields +- From the kind schema, validate requested fields: {fields_display or "[]"} +- If a requested field is missing, pick the closest valid field and note the substitution. + +STEP 3 — Build filters +- Call tool `get_node_filters(kind=target_kind{f", branch={branch!r}" if branch else ""})` to learn valid filters. +- Translate natural-language constraints in the question into the tool's filter parameters. + +STEP 4 — Query data +- Call tool `get_objects(kind=target_kind, fields=, filters={f", branch={branch!r}" if branch else ""})`. +- If status="error", surface remediation and stop. + +STEP 5 — Answer + provenance +- Provide a concise answer (bullets are fine). +- Add a short "Provenance" listing: + - resources read (URIs), + - tools called (name + key args), + - assumptions/substitutions. + +RULES +- Read resources before calling tools. +- Only use fields present in the kind schema. +- Keep outputs concise and relevant. +""" + return PromptMessage(role="user", content=TextContent(type="text", text=txt)) diff --git a/src/infrahub_mcp/resources/schema.py b/src/infrahub_mcp/resources/schema.py new file mode 100644 index 0000000..05cf367 --- /dev/null +++ b/src/infrahub_mcp/resources/schema.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated + +from fastmcp import Context, FastMCP +from pydantic import Field + +from infrahub_mcp.utils import MCPResponse, MCPToolStatus + +if TYPE_CHECKING: + from infrahub_sdk.client import InfrahubClient + +mcp = FastMCP(name="Infrahub Schema Resources", version="1.0.0") + + +@mcp.resource( + uri="infrahub://branch/{branch}/schema", + description="All available schema kinds and their attributes for a given branch.", + mime_type="application/json", + # TODO: Add audience and priorities + # annotations=Annotations(), + tags={"schema"}, +) +async def schema_all( + ctx: Context, + branch: Annotated[ + str | None, + Field(default=None, description="Branch name to read schema. Defaults to the default branch if not specified."), + ], +) -> dict: + """Return the complete schema catalog for a branch. + + Parameters + branch: Branch name to read schema from. + + Returns + MCPResponse with success status and objects. + """ + client: InfrahubClient = ctx.request_context.lifespan_context.client + + try: + data = await client.schema.all(branch=branch) + return MCPResponse( + status=MCPToolStatus.SUCCESS, data=data + ).model_dump() # Convert Pydantic models to dicts for JSON serialization + + # FIXME: Be more specific with exception handling once SDK exceptions are defined + except Exception as exc: # noqa: BLE001 + return MCPResponse( + status=MCPToolStatus.ERROR, + error=f"Failed to fetch schema catalog for branch '{branch}': {type(exc).__name__}", + remediation="Verify the branch exists and your credentials/SDK connectivity are valid.", + ).model_dump() # Convert Pydantic models to dicts for JSON serialization + + +@mcp.resource( + uri="infrahub://branch/{branch}/schema/{kind}", + description="Schema for a specific kind within a given branch.", + mime_type="application/json", + # TODO: Add audience and priorities + # annotations=Annotations(), + tags={"schema"}, +) +async def schema_by_kind( + ctx: Context, + branch: str = Field(description="Branch name to read schema from"), + kind: str = Field(description="Node kind to fetch, e.g. 'Device'"), +) -> dict: + """Return a single kind's schema for a branch. + + Parameters + branch: Branch name to read schema from. + kind: Kind of the schema to retrieve. + + Returns + MCPResponse with success status and objects. + """ + if not kind or not kind.strip(): + return MCPResponse( + status=MCPToolStatus.ERROR, + error="Parameter 'kind' must be a non-empty string.", + remediation="Provide a valid kind, e.g. 'Device' or 'Site'.", + ).model_dump() # Convert Pydantic models to dicts for JSON serialization + + client: InfrahubClient = ctx.request_context.lifespan_context.client + try: + data = await client.schema.get(kind=kind, branch=branch) + if not data: + msg = f"Schema kind '{kind}' not found in branch '{branch}'." + remediation = ( + f"List available kinds via resource 'infrahub://branch/{branch}/schema' and pick an existing kind." + ) + return MCPResponse( + status=MCPToolStatus.ERROR, + error=msg, + remediation=remediation, + ).model_dump() # Convert Pydantic models to dicts for JSON serialization + return MCPResponse( + status=MCPToolStatus.SUCCESS, data=data + ).model_dump() # Convert Pydantic models to dicts for JSON serialization + # FIXME: Be more specific with exception handling once SDK exceptions are defined + except Exception as exc: # noqa: BLE001 + msg = f"Failed to fetch schema kind '{kind}' in branch '{branch}': {type(exc).__name__}" + return MCPResponse( + status=MCPToolStatus.ERROR, + error=msg, + remediation="Confirm the branch exists and the kind name is correct; check server/SDK logs for details.", + ).model_dump() # Convert Pydantic models to dicts for JSON serialization diff --git a/src/infrahub_mcp/server.py b/src/infrahub_mcp/server.py index fcde82f..401f076 100644 --- a/src/infrahub_mcp/server.py +++ b/src/infrahub_mcp/server.py @@ -1,14 +1,21 @@ -from collections.abc import AsyncIterator +from __future__ import annotations + from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import TYPE_CHECKING from fastmcp import FastMCP from infrahub_sdk.client import InfrahubClient -from infrahub_mcp.tools.branch import mcp as branch_mcp -from infrahub_mcp.tools.gql import mcp as graphql_mcp -from infrahub_mcp.tools.nodes import mcp as nodes_mcp -from infrahub_mcp.tools.schema import mcp as schema_mcp +from infrahub_mcp.prompts.usage_guide import mcp as prompt_usage_guide_mcp +from infrahub_mcp.resources.schema import mcp as resource_schema_mcp +from infrahub_mcp.tools.branch import mcp as tool_branch_mcp +from infrahub_mcp.tools.gql import mcp as tool_graphql_mcp +from infrahub_mcp.tools.nodes import mcp as tool_nodes_mcp +from infrahub_mcp.tools.schema import mcp as tool_schema_mcp + +if TYPE_CHECKING: + from collections.abc import AsyncIterator @dataclass @@ -29,7 +36,15 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # noqa: A mcp: FastMCP = FastMCP(name="Infrahub MCP Server", version="0.1.0", lifespan=app_lifespan) # Mount the various MCPs to the main server -mcp.mount(branch_mcp) -mcp.mount(graphql_mcp) -mcp.mount(nodes_mcp) -mcp.mount(schema_mcp) + +# Resources +mcp.mount(resource_schema_mcp) + +# Prompts +mcp.mount(prompt_usage_guide_mcp) + +# Tools +mcp.mount(tool_branch_mcp) +mcp.mount(tool_graphql_mcp) +mcp.mount(tool_nodes_mcp) +mcp.mount(tool_schema_mcp) diff --git a/src/infrahub_mcp/tools/branch.py b/src/infrahub_mcp/tools/branch.py index 5688686..403cd4b 100644 --- a/src/infrahub_mcp/tools/branch.py +++ b/src/infrahub_mcp/tools/branch.py @@ -1,7 +1,8 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Annotated from fastmcp import Context, FastMCP -from infrahub_sdk.branch import BranchData from infrahub_sdk.exceptions import GraphQLError from mcp.types import ToolAnnotations from pydantic import Field @@ -10,6 +11,7 @@ if TYPE_CHECKING: from infrahub_sdk import InfrahubClient + from infrahub_sdk.branch import BranchData mcp: FastMCP = FastMCP(name="Infrahub Branches") diff --git a/src/infrahub_mcp/tools/gql.py b/src/infrahub_mcp/tools/gql.py index 33b5f07..7b5489d 100644 --- a/src/infrahub_mcp/tools/gql.py +++ b/src/infrahub_mcp/tools/gql.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Annotated, Any from fastmcp import Context, FastMCP diff --git a/src/infrahub_mcp/tools/nodes.py b/src/infrahub_mcp/tools/nodes.py index a7104e6..adc7f59 100644 --- a/src/infrahub_mcp/tools/nodes.py +++ b/src/infrahub_mcp/tools/nodes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Annotated, Any from fastmcp import Context, FastMCP @@ -15,8 +17,11 @@ mcp: FastMCP = FastMCP(name="Infrahub Nodes") -@mcp.tool(tags={"nodes", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True)) -async def get_nodes( +# FIXME: deactivate for now until we figure out what is the issue with the filters +@mcp.tool( + tags={"nodes", "retrieve"}, enabled=False, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True) +) +async def get_nodes( # noqa: PLR0913, PLR0917 ctx: Context, kind: Annotated[str, Field(description="Kind of the objects to retrieve.")], branch: Annotated[ @@ -25,6 +30,7 @@ async def get_nodes( ], filters: Annotated[dict[str, Any] | None, Field(default=None, description="Dictionary of filters to apply.")], partial_match: Annotated[bool, Field(default=False, description="Whether to use partial matching for filters.")], + limit: Annotated[int, Field(default=100, description="Maximum number of objects to retrieve. Defaults to 100.")], ) -> MCPResponse: """Get all objects of a specific kind from Infrahub. @@ -36,6 +42,7 @@ async def get_nodes( branch: Branch to retrieve the objects from. Defaults to None (uses default branch). filters: Dictionary of filters to apply. partial_match: Whether to use partial matching for filters. + limit: Maximum number of objects to retrieve. Defaults to 100. Returns: MCPResponse with success status and objects. @@ -65,6 +72,7 @@ async def get_nodes( order=Order(disable=True), populate_store=True, prefetch_relationships=True, + limit=limit, **filters, ) else: @@ -75,18 +83,16 @@ async def get_nodes( order=Order(disable=True), populate_store=True, prefetch_relationships=True, + limit=limit, ) except GraphQLError as exc: - return await _log_and_return_error(ctx=ctx, error=exc, remediation="Check the provided filters or the kind name.") + return await _log_and_return_error( + ctx=ctx, error=exc, remediation="Check the provided filters or the kind name." + ) - # Format the response with serializable data - # serialized_nodes = [] - # for node in nodes: - # node_data = await convert_node_to_dict(obj=node, branch=branch) - # serialized_nodes.append(node_data) - serialized_nodes = [obj.display_label for obj in nodes] + # Format the response with serializable data using convert_node_to_dict + serialized_nodes = [await convert_node_to_dict(obj=node, branch=branch) for node in nodes] - # Return the serialized response await ctx.debug(f"Retrieved {len(serialized_nodes)} nodes of kind {kind}") return MCPResponse( @@ -95,7 +101,7 @@ async def get_nodes( ) -@mcp.tool(tags={"nodes", "filters", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True)) +@mcp.tool(tags={"nodes", "filters", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True)) async def get_node_filters( ctx: Context, kind: Annotated[str, Field(description="Kind of the objects to retrieve.")], @@ -152,7 +158,7 @@ async def get_node_filters( ) -@mcp.tool(tags={"nodes", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True)) +@mcp.tool(tags={"nodes", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True)) async def get_related_nodes( ctx: Context, kind: Annotated[str, Field(description="Kind of the objects to retrieve.")], @@ -184,6 +190,7 @@ async def get_related_nodes( try: node_id = node_hfid = None + node = None if filters.get("ids"): node_id = filters["ids"][0] elif filters.get("hfid"): @@ -209,6 +216,12 @@ async def get_related_nodes( except Exception as exc: # noqa: BLE001 return await _log_and_return_error(ctx=ctx, error=exc) + if not node: + return await _log_and_return_error( + ctx=ctx, + error=f"No {kind} found with provided filters: {filters}", + remediation="Verify the kind and filters are correct. Use the `get_node_filters` tool to list available filters.", + ) rel = getattr(node, relation, None) if not rel: return await _log_and_return_error( diff --git a/src/infrahub_mcp/tools/schema.py b/src/infrahub_mcp/tools/schema.py index a9ff4c2..0602501 100644 --- a/src/infrahub_mcp/tools/schema.py +++ b/src/infrahub_mcp/tools/schema.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Annotated, Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated from fastmcp import Context, FastMCP from infrahub_sdk.exceptions import BranchNotFoundError, SchemaNotFoundError @@ -14,7 +16,7 @@ mcp: FastMCP = FastMCP(name="Infrahub Schemas") -@mcp.tool(tags={"schemas", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True)) +@mcp.tool(tags={"schemas", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True)) async def get_schema_mapping( ctx: Context, branch: Annotated[ @@ -52,7 +54,7 @@ async def get_schema_mapping( ) -@mcp.tool(tags={"schemas", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True)) +@mcp.tool(tags={"schemas", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True)) async def get_schema( ctx: Context, kind: Annotated[str, Field(description="Schema Kind to retrieve.")], @@ -91,7 +93,7 @@ async def get_schema( ) -@mcp.tool(tags={"schemas", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True)) +@mcp.tool(tags={"schemas", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True)) async def get_schemas( ctx: Context, branch: Annotated[ diff --git a/src/infrahub_mcp/utils.py b/src/infrahub_mcp/utils.py index 3c72f81..35b38d4 100644 --- a/src/infrahub_mcp/utils.py +++ b/src/infrahub_mcp/utils.py @@ -1,11 +1,15 @@ +from __future__ import annotations + from enum import Enum from pathlib import Path -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar -from fastmcp import Context from infrahub_sdk.node import Attribute, InfrahubNode, RelatedNode, RelationshipManager from pydantic import BaseModel +if TYPE_CHECKING: + from fastmcp import Context + CURRENT_DIRECTORY = Path(__file__).parent.resolve() PROMPTS_DIRECTORY = CURRENT_DIRECTORY / "prompts" @@ -25,13 +29,6 @@ class MCPResponse[T](BaseModel): remediation: str | None = None -def get_prompt(name: str) -> str: - prompt_file = PROMPTS_DIRECTORY / f"{name}.md" - if not prompt_file.exists(): - raise FileNotFoundError(f"Prompt file '{prompt_file}' does not exist.") - return (PROMPTS_DIRECTORY / f"{name}.md").read_text() - - async def _log_and_return_error(ctx: Context, error: str | Exception, remediation: str | None = None) -> MCPResponse: """Log an error and return a standardized error response.""" if isinstance(error, Exception):