Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ def post(self):
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
Expand All @@ -881,6 +882,7 @@ def post(self):
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
)
)

Expand All @@ -898,6 +900,7 @@ def put(self):
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
Expand All @@ -915,6 +918,7 @@ def put(self):
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}

Expand Down Expand Up @@ -951,6 +955,9 @@ def post(self):
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
Expand Down
6 changes: 6 additions & 0 deletions api/core/tools/entities/api_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class ToolProviderApiEntity(BaseModel):
server_url: Optional[str] = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool")
timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool")
headers: Optional[dict[str, str]] = Field(default=None, description="The headers of the MCP tool")

@field_validator("tools", mode="before")
@classmethod
Expand All @@ -65,6 +68,9 @@ def to_dict(self) -> dict:
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(self.optional_field("headers", self.headers))
return {
"id": self.id,
"author": self.author,
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/mcp_tool/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""add_headers_to_mcp_provider

Revision ID: c20211f18133
Revises: 8d289573e1da
Create Date: 2025-08-29 10:07:54.163626

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'c20211f18133'
down_revision = '8d289573e1da'
branch_labels = None
depends_on = None


def upgrade():
# Add headers column to tool_mcp_providers table
op.add_column('tool_mcp_providers', sa.Column('headers', sa.Text(), nullable=True))

# Add comment to the column
op.execute("COMMENT ON COLUMN tool_mcp_providers.headers IS 'Custom HTTP headers for MCP server requests (JSON format)'")


def downgrade():
# Remove headers column from tool_mcp_providers table
op.drop_column('tool_mcp_providers', 'headers')
12 changes: 12 additions & 0 deletions api/models/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class MCPToolProvider(Base):
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
# Custom headers for MCP server requests
headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)

def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
Expand Down Expand Up @@ -310,6 +312,16 @@ def provider_icon(self) -> dict[str, str] | str:
def decrypted_server_url(self) -> str:
return encrypter.decrypt_token(self.tenant_id, self.server_url)

@property
def decrypted_headers(self) -> dict:
"""Get decrypted headers for MCP server requests."""
try:
if not self.headers:
return {}
return cast(dict, json.loads(self.headers))
except Exception:
return {}

@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
Expand Down
28 changes: 27 additions & 1 deletion api/services/tools/mcp_tools_manage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def create_mcp_provider(
server_identifier: str,
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
Expand Down Expand Up @@ -95,6 +96,7 @@ def create_mcp_provider(
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
headers=json.dumps(headers) if headers else None,
)
db.session.add(mcp_tool)
db.session.commit()
Expand All @@ -118,9 +120,21 @@ def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> T
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout

try:
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
Expand Down Expand Up @@ -172,6 +186,7 @@ def update_mcp_provider(
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)

Expand Down Expand Up @@ -207,6 +222,8 @@ def update_mcp_provider(
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
mcp_provider.headers = json.dumps(headers) if headers else None
db.session.commit()
except IntegrityError as e:
db.session.rollback()
Expand Down Expand Up @@ -242,13 +259,22 @@ def update_mcp_provider_credentials(

@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout

try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {
Expand Down
3 changes: 3 additions & 0 deletions api/services/tools/tools_transform_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool =
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
headers=db_provider.decrypted_headers,
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,14 @@ def test_list_mcp_tool_from_remote_server_success(

# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp", mcp_provider.id, tenant.id, authed=False, for_list=True
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)

def test_list_mcp_tool_from_remote_server_auth_error(
Expand Down Expand Up @@ -1181,6 +1188,11 @@ def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_
db_session_with_containers, mock_external_service_dependencies
)

# Create MCP provider first
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)

# Mock MCPClient and its context manager
mock_tools = [
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_1", "description": "Test tool 1"}})(),
Expand All @@ -1194,7 +1206,7 @@ def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_

# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", "test_provider_id", tenant.id
"https://example.com/mcp", mcp_provider.id, tenant.id
)

# Assert: Verify the expected outcomes
Expand All @@ -1213,7 +1225,14 @@ def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_

# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp", "test_provider_id", tenant.id, authed=False, for_list=True
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)

def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
Expand All @@ -1231,6 +1250,11 @@ def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mo
db_session_with_containers, mock_external_service_dependencies
)

# Create MCP provider first
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)

# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPAuthError
Expand All @@ -1240,7 +1264,7 @@ def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mo

# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", "test_provider_id", tenant.id
"https://example.com/mcp", mcp_provider.id, tenant.id
)

# Assert: Verify the expected outcomes
Expand All @@ -1265,6 +1289,11 @@ def test_re_connect_mcp_provider_connection_error(
db_session_with_containers, mock_external_service_dependencies
)

# Create MCP provider first
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)

# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPError
Expand All @@ -1274,4 +1303,4 @@ def test_re_connect_mcp_provider_connection_error(

# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", "test_provider_id", tenant.id)
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id)
Loading