Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/infrahub_mcp/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

NAMESPACES_INTERNAL = ["Internal", "Profile", "Template"]

schema_attribute_type_mapping = {
Expand Down
Empty file.
8 changes: 0 additions & 8 deletions src/infrahub_mcp/prompts/main.md

This file was deleted.

73 changes: 73 additions & 0 deletions src/infrahub_mcp/prompts/usage_guide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from typing import Annotated

from fastmcp import FastMCP
from mcp.types import PromptMessage, TextContent
from pydantic import Field

mcp = FastMCP(name="Infrahub Usage Prompts", version="1.0.0")


@mcp.prompt(
name="answer_infra_question",
description="Answer infra questions by reading branch-scoped schema resources first, then using mapping+filters+query tools.",
tags={"schema", "query", "infra"},
)
def answer_infra_question(
question: Annotated[str, Field(description="User question about the infra")],
kind_hint: Annotated[str | None, Field(default=None, description="Optional guess/hint for the schema kind")],
fields: Annotated[list[str] | None, Field(default=None, description="Optional list of fields to return")],
branch: Annotated[
str | None,
Field(default=None, description="Branch to retrieve the objects from. Defaults to None (uses default branch)."),
],
) -> PromptMessage:
# Resolve URIs (supports your 'current' static resource if you added it; otherwise set a default branch)
resolved_branch = branch or "current"
base_uri = f"infrahub://branch/{resolved_branch}/schema"
kind_uri = f"infrahub://branch/{resolved_branch}/schema/{{target_kind}}"
fields_display = fields if fields is not None else []

txt = f"""
You are an infrastructure specialist.

User question: {question}

PIPELINE (follow exactly):

STEP 0 — Read schema (resources)
- Read: {base_uri}
- After choosing a kind, also read: {kind_uri}
- If any resource returns status="error", use its remediation and stop.

STEP 1 — Identify target kind
- Try tool `schema_get_mapping(question)` to map the question to a kind.
- If that fails, infer from the schema resource (names/attributes/relationships).
- Use kind_hint if helpful: "{kind_hint or ""}". Decide a single target_kind and note any assumptions briefly.

STEP 2 — Validate fields
- From the kind schema, validate requested fields: {fields_display or "[]"}
- If a requested field is missing, pick the closest valid field and note the substitution.

STEP 3 — Build filters
- Call tool `get_node_filters(kind=target_kind{f", branch={branch!r}" if branch else ""})` to learn valid filters.
- Translate natural-language constraints in the question into the tool's filter parameters.

STEP 4 — Query data
- Call tool `get_objects(kind=target_kind, fields=<validated fields>, filters=<built filters>{f", branch={branch!r}" if branch else ""})`.
- If status="error", surface remediation and stop.

STEP 5 — Answer + provenance
- Provide a concise answer (bullets are fine).
- Add a short "Provenance" listing:
- resources read (URIs),
- tools called (name + key args),
- assumptions/substitutions.

RULES
- Read resources before calling tools.
- Only use fields present in the kind schema.
- Keep outputs concise and relevant.
"""
return PromptMessage(role="user", content=TextContent(type="text", text=txt))
108 changes: 108 additions & 0 deletions src/infrahub_mcp/resources/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated

from fastmcp import Context, FastMCP
from pydantic import Field

from infrahub_mcp.utils import MCPResponse, MCPToolStatus

if TYPE_CHECKING:
from infrahub_sdk.client import InfrahubClient

mcp = FastMCP(name="Infrahub Schema Resources", version="1.0.0")


@mcp.resource(
uri="infrahub://branch/{branch}/schema",
description="All available schema kinds and their attributes for a given branch.",
mime_type="application/json",
# TODO: Add audience and priorities
# annotations=Annotations(),
tags={"schema"},
)
async def schema_all(
ctx: Context,
branch: Annotated[
str | None,
Field(default=None, description="Branch name to read schema. Defaults to the default branch if not specified."),
],
) -> dict:
"""Return the complete schema catalog for a branch.

Parameters
branch: Branch name to read schema from.

Returns
MCPResponse with success status and objects.
"""
client: InfrahubClient = ctx.request_context.lifespan_context.client

try:
data = await client.schema.all(branch=branch)
return MCPResponse(
status=MCPToolStatus.SUCCESS, data=data
).model_dump() # Convert Pydantic models to dicts for JSON serialization

# FIXME: Be more specific with exception handling once SDK exceptions are defined
except Exception as exc: # noqa: BLE001
return MCPResponse(
status=MCPToolStatus.ERROR,
error=f"Failed to fetch schema catalog for branch '{branch}': {type(exc).__name__}",
remediation="Verify the branch exists and your credentials/SDK connectivity are valid.",
).model_dump() # Convert Pydantic models to dicts for JSON serialization


@mcp.resource(
uri="infrahub://branch/{branch}/schema/{kind}",
description="Schema for a specific kind within a given branch.",
mime_type="application/json",
# TODO: Add audience and priorities
# annotations=Annotations(),
tags={"schema"},
)
async def schema_by_kind(
ctx: Context,
branch: str = Field(description="Branch name to read schema from"),
kind: str = Field(description="Node kind to fetch, e.g. 'Device'"),
) -> dict:
"""Return a single kind's schema for a branch.

Parameters
branch: Branch name to read schema from.
kind: Kind of the schema to retrieve.

Returns
MCPResponse with success status and objects.
"""
if not kind or not kind.strip():
return MCPResponse(
status=MCPToolStatus.ERROR,
error="Parameter 'kind' must be a non-empty string.",
remediation="Provide a valid kind, e.g. 'Device' or 'Site'.",
).model_dump() # Convert Pydantic models to dicts for JSON serialization

client: InfrahubClient = ctx.request_context.lifespan_context.client
try:
data = await client.schema.get(kind=kind, branch=branch)
if not data:
msg = f"Schema kind '{kind}' not found in branch '{branch}'."
remediation = (
f"List available kinds via resource 'infrahub://branch/{branch}/schema' and pick an existing kind."
)
return MCPResponse(
status=MCPToolStatus.ERROR,
error=msg,
remediation=remediation,
).model_dump() # Convert Pydantic models to dicts for JSON serialization
return MCPResponse(
status=MCPToolStatus.SUCCESS, data=data
).model_dump() # Convert Pydantic models to dicts for JSON serialization
# FIXME: Be more specific with exception handling once SDK exceptions are defined
except Exception as exc: # noqa: BLE001
msg = f"Failed to fetch schema kind '{kind}' in branch '{branch}': {type(exc).__name__}"
return MCPResponse(
status=MCPToolStatus.ERROR,
error=msg,
remediation="Confirm the branch exists and the kind name is correct; check server/SDK logs for details.",
).model_dump() # Convert Pydantic models to dicts for JSON serialization
33 changes: 24 additions & 9 deletions src/infrahub_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from collections.abc import AsyncIterator
from __future__ import annotations

from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING

from fastmcp import FastMCP
from infrahub_sdk.client import InfrahubClient

from infrahub_mcp.tools.branch import mcp as branch_mcp
from infrahub_mcp.tools.gql import mcp as graphql_mcp
from infrahub_mcp.tools.nodes import mcp as nodes_mcp
from infrahub_mcp.tools.schema import mcp as schema_mcp
from infrahub_mcp.prompts.usage_guide import mcp as prompt_usage_guide_mcp
from infrahub_mcp.resources.schema import mcp as resource_schema_mcp
from infrahub_mcp.tools.branch import mcp as tool_branch_mcp
from infrahub_mcp.tools.gql import mcp as tool_graphql_mcp
from infrahub_mcp.tools.nodes import mcp as tool_nodes_mcp
from infrahub_mcp.tools.schema import mcp as tool_schema_mcp

if TYPE_CHECKING:
from collections.abc import AsyncIterator


@dataclass
Expand All @@ -29,7 +36,15 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # noqa: A
mcp: FastMCP = FastMCP(name="Infrahub MCP Server", version="0.1.0", lifespan=app_lifespan)

# Mount the various MCPs to the main server
mcp.mount(branch_mcp)
mcp.mount(graphql_mcp)
mcp.mount(nodes_mcp)
mcp.mount(schema_mcp)

# Resources
mcp.mount(resource_schema_mcp)

# Prompts
mcp.mount(prompt_usage_guide_mcp)

# Tools
mcp.mount(tool_branch_mcp)
mcp.mount(tool_graphql_mcp)
mcp.mount(tool_nodes_mcp)
mcp.mount(tool_schema_mcp)
4 changes: 3 additions & 1 deletion src/infrahub_mcp/tools/branch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated

from fastmcp import Context, FastMCP
from infrahub_sdk.branch import BranchData
from infrahub_sdk.exceptions import GraphQLError
from mcp.types import ToolAnnotations
from pydantic import Field
Expand All @@ -10,6 +11,7 @@

if TYPE_CHECKING:
from infrahub_sdk import InfrahubClient
from infrahub_sdk.branch import BranchData

mcp: FastMCP = FastMCP(name="Infrahub Branches")

Expand Down
2 changes: 2 additions & 0 deletions src/infrahub_mcp/tools/gql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, Any

from fastmcp import Context, FastMCP
Expand Down
37 changes: 25 additions & 12 deletions src/infrahub_mcp/tools/nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, Any

from fastmcp import Context, FastMCP
Expand All @@ -15,8 +17,11 @@
mcp: FastMCP = FastMCP(name="Infrahub Nodes")


@mcp.tool(tags={"nodes", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True))
async def get_nodes(
# FIXME: deactivate for now until we figure out what is the issue with the filters
@mcp.tool(
tags={"nodes", "retrieve"}, enabled=False, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True)
)
async def get_nodes( # noqa: PLR0913, PLR0917
ctx: Context,
kind: Annotated[str, Field(description="Kind of the objects to retrieve.")],
branch: Annotated[
Expand All @@ -25,6 +30,7 @@ async def get_nodes(
],
filters: Annotated[dict[str, Any] | None, Field(default=None, description="Dictionary of filters to apply.")],
partial_match: Annotated[bool, Field(default=False, description="Whether to use partial matching for filters.")],
limit: Annotated[int, Field(default=100, description="Maximum number of objects to retrieve. Defaults to 100.")],
) -> MCPResponse:
"""Get all objects of a specific kind from Infrahub.

Expand All @@ -36,6 +42,7 @@ async def get_nodes(
branch: Branch to retrieve the objects from. Defaults to None (uses default branch).
filters: Dictionary of filters to apply.
partial_match: Whether to use partial matching for filters.
limit: Maximum number of objects to retrieve. Defaults to 100.

Returns:
MCPResponse with success status and objects.
Expand Down Expand Up @@ -65,6 +72,7 @@ async def get_nodes(
order=Order(disable=True),
populate_store=True,
prefetch_relationships=True,
limit=limit,
**filters,
)
else:
Expand All @@ -75,18 +83,16 @@ async def get_nodes(
order=Order(disable=True),
populate_store=True,
prefetch_relationships=True,
limit=limit,
)
except GraphQLError as exc:
return await _log_and_return_error(ctx=ctx, error=exc, remediation="Check the provided filters or the kind name.")
return await _log_and_return_error(
ctx=ctx, error=exc, remediation="Check the provided filters or the kind name."
)

# Format the response with serializable data
# serialized_nodes = []
# for node in nodes:
# node_data = await convert_node_to_dict(obj=node, branch=branch)
# serialized_nodes.append(node_data)
serialized_nodes = [obj.display_label for obj in nodes]
# Format the response with serializable data using convert_node_to_dict
serialized_nodes = [await convert_node_to_dict(obj=node, branch=branch) for node in nodes]

# Return the serialized response
await ctx.debug(f"Retrieved {len(serialized_nodes)} nodes of kind {kind}")

return MCPResponse(
Expand All @@ -95,7 +101,7 @@ async def get_nodes(
)


@mcp.tool(tags={"nodes", "filters", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True))
@mcp.tool(tags={"nodes", "filters", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True))
async def get_node_filters(
ctx: Context,
kind: Annotated[str, Field(description="Kind of the objects to retrieve.")],
Expand Down Expand Up @@ -152,7 +158,7 @@ async def get_node_filters(
)


@mcp.tool(tags={"nodes", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True))
@mcp.tool(tags={"nodes", "retrieve"}, annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True))
async def get_related_nodes(
ctx: Context,
kind: Annotated[str, Field(description="Kind of the objects to retrieve.")],
Expand Down Expand Up @@ -184,6 +190,7 @@ async def get_related_nodes(

try:
node_id = node_hfid = None
node = None
if filters.get("ids"):
node_id = filters["ids"][0]
elif filters.get("hfid"):
Expand All @@ -209,6 +216,12 @@ async def get_related_nodes(
except Exception as exc: # noqa: BLE001
return await _log_and_return_error(ctx=ctx, error=exc)

if not node:
return await _log_and_return_error(
ctx=ctx,
error=f"No {kind} found with provided filters: {filters}",
remediation="Verify the kind and filters are correct. Use the `get_node_filters` tool to list available filters.",
)
rel = getattr(node, relation, None)
if not rel:
return await _log_and_return_error(
Expand Down
Loading