Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ def post(self):
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
Expand All @@ -881,6 +882,7 @@ def post(self):
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
)
)

Expand All @@ -898,6 +900,7 @@ def put(self):
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
Expand All @@ -915,6 +918,7 @@ def put(self):
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}

Expand Down Expand Up @@ -951,6 +955,9 @@ def post(self):
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
Expand Down
8 changes: 8 additions & 0 deletions api/core/tools/entities/api_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class ToolProviderApiEntity(BaseModel):
server_url: Optional[str] = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool")
timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool")
original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool")

@field_validator("tools", mode="before")
@classmethod
Expand All @@ -65,6 +69,10 @@ def to_dict(self):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {
"id": self.id,
"author": self.author,
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/mcp_tool/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def from_db(cls, db_provider: MCPToolProvider) -> Self:
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""add_headers_to_mcp_provider

Revision ID: c20211f18133
Revises: 8d289573e1da
Create Date: 2025-08-29 10:07:54.163626

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'c20211f18133'
down_revision = 'b95962a3885c'
branch_labels = None
depends_on = None


def upgrade():
# Add encrypted_headers column to tool_mcp_providers table
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))


def downgrade():
# Remove encrypted_headers column from tool_mcp_providers table
op.drop_column('tool_mcp_providers', 'encrypted_headers')
58 changes: 58 additions & 0 deletions api/models/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class MCPToolProvider(Base):
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
# encrypted headers for MCP server requests
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)

def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
Expand Down Expand Up @@ -310,6 +312,62 @@ def provider_icon(self) -> dict[str, str] | str:
def decrypted_server_url(self) -> str:
return encrypter.decrypt_token(self.tenant_id, self.server_url)

@property
def decrypted_headers(self) -> dict[str, Any]:
"""Get decrypted headers for MCP server requests."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter

try:
if not self.encrypted_headers:
return {}

headers_data = json.loads(self.encrypted_headers)

# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]

encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)

result = encrypter_instance.decrypt(headers_data)
return result
except Exception:
return {}

@property
def masked_headers(self) -> dict[str, Any]:
"""Get masked headers for frontend display."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter

try:
if not self.encrypted_headers:
return {}

headers_data = json.loads(self.encrypted_headers)

# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]

encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)

# First decrypt, then mask
decrypted_headers = encrypter_instance.decrypt(headers_data)
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
return result
except Exception:
return {}

@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
Expand Down
71 changes: 69 additions & 2 deletions api/services/tools/mcp_tools_manage_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib
import json
from datetime import datetime
from typing import Any
from typing import Any, cast

from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
Expand All @@ -27,6 +27,36 @@ class MCPToolManageService:
Service class for managing mcp tools.
"""

@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.

Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption

Returns:
Dictionary with all headers encrypted
"""
if not headers:
return {}

from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter

# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]

encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)

return cast(dict[str, str], encrypter_instance.encrypt(headers))

@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
Expand Down Expand Up @@ -61,6 +91,7 @@ def create_mcp_provider(
server_identifier: str,
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
Expand All @@ -83,6 +114,12 @@ def create_mcp_provider(
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)

mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
Expand All @@ -95,6 +132,7 @@ def create_mcp_provider(
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
Expand All @@ -118,9 +156,21 @@ def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> T
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout

try:
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
Expand Down Expand Up @@ -172,6 +222,7 @@ def update_mcp_provider(
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)

Expand Down Expand Up @@ -207,6 +258,13 @@ def update_mcp_provider(
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
# Encrypt headers
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
mcp_provider.encrypted_headers = None
db.session.commit()
except IntegrityError as e:
db.session.rollback()
Expand Down Expand Up @@ -242,13 +300,22 @@ def update_mcp_provider_credentials(

@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout

try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {
Expand Down
4 changes: 4 additions & 0 deletions api/services/tools/tools_transform_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool =
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
)

@staticmethod
Expand Down
Loading
Loading