Skip to content
36 changes: 36 additions & 0 deletions airbyte/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,39 @@ def _str_to_bool(value: str) -> bool:
- https://docs.airbyte.com/api-documentation#configuration-api-deprecated
- https://github.com/airbytehq/airbyte-platform-internal/blob/master/oss/airbyte-api/server-api/src/main/openapi/config.yaml
"""

# MCP (Model Context Protocol) Constants

MCP_TOOL_DOMAINS: list[str] = ["cloud", "local", "registry"]
"""Valid MCP tool domains available in the server.

- `cloud`: Tools for managing Airbyte Cloud resources (sources, destinations, connections)
- `local`: Tools for local operations (connector validation, caching, SQL queries)
- `registry`: Tools for querying the Airbyte connector registry
"""

AIRBYTE_MCP_DOMAINS: list[str] | None = [
d.strip().lower() for d in os.getenv("AIRBYTE_MCP_DOMAINS", "").split(",") if d.strip()
] or None
"""Enabled MCP tool domains from the `AIRBYTE_MCP_DOMAINS` environment variable.

Accepts a comma-separated list of domain names (e.g., "registry,cloud").
If set, only tools from these domains will be advertised by the MCP server.
If not set (None), all domains are enabled by default.

Values are case-insensitive and whitespace is trimmed.
"""

AIRBYTE_MCP_DOMAINS_DISABLED: list[str] | None = [
d.strip().lower() for d in os.getenv("AIRBYTE_MCP_DOMAINS_DISABLED", "").split(",") if d.strip()
] or None
"""Disabled MCP tool domains from the `AIRBYTE_MCP_DOMAINS_DISABLED` environment variable.

Accepts a comma-separated list of domain names (e.g., "registry").
Tools from these domains will not be advertised by the MCP server.

When both `AIRBYTE_MCP_DOMAINS` and `AIRBYTE_MCP_DOMAINS_DISABLED` are set,
the disabled list takes precedence (subtracts from the enabled list).

Values are case-insensitive and whitespace is trimmed.
"""
89 changes: 86 additions & 3 deletions airbyte/mcp/_tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@

import inspect
import os
import warnings
from collections.abc import Callable
from functools import lru_cache
from typing import Any, Literal, TypeVar

from airbyte.constants import (
AIRBYTE_MCP_DOMAINS,
AIRBYTE_MCP_DOMAINS_DISABLED,
MCP_TOOL_DOMAINS,
)
from airbyte.mcp._annotations import (
DESTRUCTIVE_HINT,
IDEMPOTENT_HINT,
Expand All @@ -28,6 +35,7 @@
AIRBYTE_CLOUD_MCP_SAFE_MODE = os.environ.get("AIRBYTE_CLOUD_MCP_SAFE_MODE", "1").strip() != "0"
AIRBYTE_CLOUD_WORKSPACE_ID_IS_SET = bool(os.environ.get("AIRBYTE_CLOUD_WORKSPACE_ID", "").strip())


_REGISTERED_TOOLS: list[tuple[Callable[..., Any], dict[str, Any]]] = []
_GUIDS_CREATED_IN_SESSION: set[str] = set()

Expand Down Expand Up @@ -66,6 +74,76 @@ def check_guid_created_in_session(guid: str) -> None:
)


@lru_cache(maxsize=1)
def _resolve_mcp_domain_filters() -> tuple[set[str], set[str]]:
"""Resolve MCP domain filters from environment variables.

This function is cached to ensure warnings are only emitted once per process.

Returns:
Tuple of (enabled_domains, disabled_domains) as sets.
If an env var is not set, the corresponding set will be empty.
"""
known_domains = set(MCP_TOOL_DOMAINS)
enabled = set(AIRBYTE_MCP_DOMAINS or [])
disabled = set(AIRBYTE_MCP_DOMAINS_DISABLED or [])

# Check for unknown domains and warn
unknown_enabled = enabled - known_domains
unknown_disabled = disabled - known_domains

if unknown_enabled or unknown_disabled:
parts: list[str] = []
if unknown_enabled:
parts.append(
f"AIRBYTE_MCP_DOMAINS contains unknown domain(s): {sorted(unknown_enabled)}"
)
if unknown_disabled:
parts.append(
"AIRBYTE_MCP_DOMAINS_DISABLED contains unknown domain(s): "
f"{sorted(unknown_disabled)}"
)
known_list = ", ".join(sorted(known_domains))
warning_message = "; ".join(parts) + f". Known MCP domains are: [{known_list}]."
warnings.warn(warning_message, stacklevel=3)

return enabled, disabled


def is_domain_enabled(domain: str) -> bool:
"""Check if a domain is enabled based on AIRBYTE_MCP_DOMAINS and AIRBYTE_MCP_DOMAINS_DISABLED.

The logic is:
- If neither env var is set: all domains are enabled
- If only AIRBYTE_MCP_DOMAINS is set: only those domains are enabled
- If only AIRBYTE_MCP_DOMAINS_DISABLED is set: all domains except those are enabled
- If both are set: disabled domains subtract from enabled domains

Args:
domain: The domain to check (e.g., "cloud", "local", "registry")

Returns:
True if the domain is enabled, False otherwise
"""
enabled, disabled = _resolve_mcp_domain_filters()
domain_lower = domain.lower()

# If neither env var is set, all domains are enabled
if not enabled and not disabled:
return True

# If only disabled list is set, enable all except disabled
if not enabled and disabled:
return domain_lower not in disabled

# If only enabled list is set, only enable those domains
if enabled and not disabled:
return domain_lower in enabled

# Both are set: disabled list subtracts from enabled list
return domain_lower in enabled and domain_lower not in disabled


def should_register_tool(annotations: dict[str, Any]) -> bool:
"""Check if a tool should be registered based on mode settings.

Expand All @@ -75,10 +153,15 @@ def should_register_tool(annotations: dict[str, Any]) -> bool:
Returns:
True if the tool should be registered, False if it should be filtered out
"""
if annotations.get("domain") != "cloud":
return True
domain = annotations.get("domain")
domain_normalized = domain.lower() if isinstance(domain, str) else None

# Check domain filtering first
if domain_normalized and not is_domain_enabled(domain_normalized):
return False

if AIRBYTE_CLOUD_MCP_READONLY_MODE:
# Cloud-specific readonly mode check (case-insensitive)
if domain_normalized == "cloud" and AIRBYTE_CLOUD_MCP_READONLY_MODE:
is_readonly = annotations.get(READ_ONLY_HINT, False)
if not is_readonly:
return False
Expand Down
153 changes: 153 additions & 0 deletions tests/unit_tests/test_mcp_tool_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""Unit tests for MCP tool utility functions."""

from __future__ import annotations

import importlib
import warnings
from unittest.mock import patch

import pytest

import airbyte.constants as constants
import airbyte.mcp._tool_utils as tool_utils
from airbyte.mcp._annotations import READ_ONLY_HINT

# (enabled, disabled, domain, readonly_mode, is_readonly, domain_enabled, should_register)
_DOMAIN_CASES = [
(None, None, "cloud", False, False, True, True),
(None, None, "registry", False, False, True, True),
(None, None, "local", False, False, True, True),
(["cloud"], None, "cloud", False, False, True, True),
(["cloud"], None, "registry", False, False, False, False),
(None, ["registry"], "registry", False, False, False, False),
(None, ["registry"], "cloud", False, False, True, True),
(["registry", "cloud"], ["registry"], "cloud", False, False, True, True),
(["registry", "cloud"], ["registry"], "registry", False, False, False, False),
(["cloud"], ["registry"], "local", False, False, False, False),
(["CLOUD"], None, "cloud", False, False, True, True),
(["cloud"], None, "CLOUD", False, False, True, True),
(None, None, "cloud", True, False, True, False),
(None, None, "cloud", True, True, True, True),
(None, None, "registry", True, False, True, True),
(["cloud"], None, "cloud", True, True, True, True),
(["registry"], None, "cloud", True, True, False, False),
]


@pytest.mark.parametrize(
"enabled,disabled,domain,readonly_mode,is_readonly,domain_enabled,should_register",
_DOMAIN_CASES,
)
def test_domain_logic(
enabled: list[str] | None,
disabled: list[str] | None,
domain: str,
readonly_mode: bool,
is_readonly: bool,
domain_enabled: bool,
should_register: bool,
) -> None:
norm_enabled = [d.lower() for d in enabled] if enabled else None
norm_disabled = [d.lower() for d in disabled] if disabled else None
with (
patch("airbyte.mcp._tool_utils.AIRBYTE_MCP_DOMAINS", norm_enabled),
patch("airbyte.mcp._tool_utils.AIRBYTE_MCP_DOMAINS_DISABLED", norm_disabled),
patch("airbyte.mcp._tool_utils.AIRBYTE_CLOUD_MCP_READONLY_MODE", readonly_mode),
):
tool_utils._resolve_mcp_domain_filters.cache_clear()
assert tool_utils.is_domain_enabled(domain) == domain_enabled
assert (
tool_utils.should_register_tool({
"domain": domain,
READ_ONLY_HINT: is_readonly,
})
== should_register
)


# (env_var, attr, env_value, expected)
_ENV_PARSE_CASES = [
("AIRBYTE_MCP_DOMAINS", "AIRBYTE_MCP_DOMAINS", "", None),
("AIRBYTE_MCP_DOMAINS", "AIRBYTE_MCP_DOMAINS", "cloud", ["cloud"]),
(
"AIRBYTE_MCP_DOMAINS",
"AIRBYTE_MCP_DOMAINS",
"registry,cloud",
["registry", "cloud"],
),
(
"AIRBYTE_MCP_DOMAINS",
"AIRBYTE_MCP_DOMAINS",
"registry, cloud",
["registry", "cloud"],
),
(
"AIRBYTE_MCP_DOMAINS",
"AIRBYTE_MCP_DOMAINS",
"REGISTRY,CLOUD",
["registry", "cloud"],
),
(
"AIRBYTE_MCP_DOMAINS",
"AIRBYTE_MCP_DOMAINS",
"registry,,cloud",
["registry", "cloud"],
),
("AIRBYTE_MCP_DOMAINS_DISABLED", "AIRBYTE_MCP_DOMAINS_DISABLED", "", None),
(
"AIRBYTE_MCP_DOMAINS_DISABLED",
"AIRBYTE_MCP_DOMAINS_DISABLED",
"registry",
["registry"],
),
(
"AIRBYTE_MCP_DOMAINS_DISABLED",
"AIRBYTE_MCP_DOMAINS_DISABLED",
"registry,local",
["registry", "local"],
),
]


@pytest.mark.parametrize("env_var,attr,env_value,expected", _ENV_PARSE_CASES)
def test_env_parsing(
env_var: str, attr: str, env_value: str, expected: list[str] | None
) -> None:
with patch.dict("os.environ", {env_var: env_value}, clear=False):
importlib.reload(constants)
assert getattr(constants, attr) == expected
importlib.reload(constants)


# (env_var, env_value, warning_fragment)
_WARNING_CASES = [
(
"AIRBYTE_MCP_DOMAINS",
"cloud,invalid",
"AIRBYTE_MCP_DOMAINS contains unknown domain(s)",
),
(
"AIRBYTE_MCP_DOMAINS_DISABLED",
"registry,fake",
"AIRBYTE_MCP_DOMAINS_DISABLED contains unknown domain(s)",
),
]


@pytest.mark.parametrize("env_var,env_value,fragment", _WARNING_CASES)
def test_unknown_domain_warning(env_var: str, env_value: str, fragment: str) -> None:
with (
patch.dict("os.environ", {env_var: env_value}, clear=False),
warnings.catch_warnings(record=True) as caught,
):
warnings.simplefilter("always")
importlib.reload(constants)
importlib.reload(tool_utils)
tool_utils._resolve_mcp_domain_filters.cache_clear()
tool_utils._resolve_mcp_domain_filters()
messages = [str(w.message) for w in caught]
assert any(fragment in m for m in messages)
assert any("Known MCP domains are:" in m for m in messages)
importlib.reload(constants)
importlib.reload(tool_utils)