Skip to content

Commit e6e4ee6

Browse files
authored
Make gateway slug use config value (#439)
* fixed gateway separator issue Signed-off-by: Madhav Kandukuri <[email protected]> * Fix pylint and fail a test Signed-off-by: Madhav Kandukuri <[email protected]> * linting fixes Signed-off-by: Madhav Kandukuri <[email protected]> * Linting fixes Signed-off-by: Madhav Kandukuri <[email protected]> * Flake8 fixes Signed-off-by: Madhav Kandukuri <[email protected]> * Ruff fixes Signed-off-by: Madhav Kandukuri <[email protected]> * Fix REST tool addition Signed-off-by: Madhav Kandukuri <[email protected]> * Fix REST tool name Signed-off-by: Madhav Kandukuri <[email protected]> * Allow dot in tool names Signed-off-by: Madhav Kandukuri <[email protected]> * Linting fixes Signed-off-by: Madhav Kandukuri <[email protected]> --------- Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent a11fa60 commit e6e4ee6

15 files changed

+196
-102
lines changed

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,4 @@ DEBUG=false
263263

264264
# Gateway tool name separator
265265
GATEWAY_TOOL_NAME_SEPARATOR=-
266+
VALID_SLUG_SEPARATOR_REGEXP= r"^(-{1,2}|[_.])$"

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
[flake8]
22
max-line-length = 600
3+
per-file-ignores =
4+
mcpgateway/services/gateway_service.py: DAR401,DAR402

mcpgateway/alembic/versions/b77ca9d2de7e_uuid_pk_and_slug_refactor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import uuid
1313

1414
# Third-Party
15-
from alembic import op
1615
import sqlalchemy as sa
1716
from sqlalchemy.orm import Session
17+
from alembic import op
1818

1919
# First-Party
2020
from mcpgateway.config import settings

mcpgateway/alembic/versions/e4fc04d1a442_add_annotations_to_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from typing import Sequence, Union
1212

1313
# Third-Party
14-
from alembic import op
1514
import sqlalchemy as sa
15+
from alembic import op
1616

1717
# revision identifiers, used by Alembic.
1818
revision: str = "e4fc04d1a442"

mcpgateway/alembic/versions/e75490e949b1_add_improved_status_to_tables.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from typing import Sequence, Union
1111

1212
# Third-Party
13-
# Alembic / SQLAlchemy
14-
from alembic import op
1513
import sqlalchemy as sa
14+
from alembic import op
1615

1716
# Revision identifiers.
1817
revision: str = "e75490e949b1"

mcpgateway/bootstrap_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
import logging
2626

2727
# Third-Party
28+
from sqlalchemy import create_engine, inspect
2829
from alembic import command
2930
from alembic.config import Config
30-
from sqlalchemy import create_engine, inspect
3131

3232
# First-Party
3333
from mcpgateway.config import settings

mcpgateway/config.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@
5050
from functools import lru_cache
5151
from importlib.resources import files
5252
import json
53+
import logging
5354
from pathlib import Path
54-
from typing import Annotated, Any, Dict, List, Optional, Set, Union
55+
import re
56+
from typing import Annotated, Any, Dict, List, Optional, Set, Union, ClassVar
5557

5658
# Third-Party
5759
from fastapi import HTTPException
@@ -61,6 +63,14 @@
6163
from pydantic import field_validator
6264
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
6365

66+
logging.basicConfig(
67+
level=logging.INFO,
68+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
69+
datefmt="%H:%M:%S",
70+
)
71+
72+
logger = logging.getLogger(__name__)
73+
6474

6575
class Settings(BaseSettings):
6676
"""
@@ -237,6 +247,26 @@ def _parse_federation_peers(cls, v):
237247
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore")
238248

239249
gateway_tool_name_separator: str = "-"
250+
valid_slug_separator_regexp: ClassVar[str] = r"^(-{1,2}|[_.])$"
251+
252+
@field_validator("gateway_tool_name_separator")
253+
@classmethod
254+
def must_be_allowed_sep(cls, v: str) -> str:
255+
"""Validate the gateway tool name separator.
256+
257+
Args:
258+
v: The separator value to validate.
259+
260+
Returns:
261+
The validated separator, defaults to '-' if invalid.
262+
"""
263+
if not re.fullmatch(cls.valid_slug_separator_regexp, v):
264+
logger.warning(
265+
f"Invalid gateway_tool_name_separator '{v}'. Must be '-', '--', '_' or '.'. Defaulting to '-'.",
266+
stacklevel=2,
267+
)
268+
return "-"
269+
return v
240270

241271
@property
242272
def api_key(self) -> str:
@@ -393,11 +423,11 @@ def validate_database(self) -> None:
393423
validation_allowed_url_schemes: List[str] = ["http://", "https://", "ws://", "wss://"]
394424

395425
# Character validation patterns
396-
validation_name_pattern: str = r"^[a-zA-Z0-9_\-\s]+$" # Allow spaces for names
426+
validation_name_pattern: str = r"^[a-zA-Z0-9_.\-\s]+$" # Allow spaces for names
397427
validation_identifier_pattern: str = r"^[a-zA-Z0-9_\-\.]+$" # No spaces for IDs
398428
validation_safe_uri_pattern: str = r"^[a-zA-Z0-9_\-.:/?=&%]+$"
399429
validation_unsafe_uri_pattern: str = r'[<>"\'\\]'
400-
validation_tool_name_pattern: str = r"^[a-zA-Z][a-zA-Z0-9_-]*$" # MCP tool naming
430+
validation_tool_name_pattern: str = r"^[a-zA-Z][a-zA-Z0-9._-]*$" # MCP tool naming
401431

402432
# MCP-compliant size limits (configurable via env)
403433
validation_max_name_length: int = 255

mcpgateway/db.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
# Standard
2323
from datetime import datetime, timezone
24-
import re
2524
from typing import Any, Dict, List, Optional
2625
import uuid
2726

@@ -62,6 +61,8 @@
6261
from mcpgateway.models import ResourceContent
6362
from mcpgateway.utils.create_slug import slugify
6463
from mcpgateway.utils.db_isready import wait_for_db_ready
64+
from mcpgateway.validators import SecurityValidator
65+
6566

6667
# ---------------------------------------------------------------------------
6768
# 1. Parse the URL so we can inspect backend ("postgresql", "sqlite", ...)
@@ -128,35 +129,41 @@ def utc_now() -> datetime:
128129
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
129130

130131

132+
def refresh_slugs_on_startup():
133+
"""Refresh slugs for all gateways and names of tools on startup."""
134+
135+
with SessionLocal() as session:
136+
gateways = session.query(Gateway).all()
137+
updated = False
138+
for gateway in gateways:
139+
new_slug = slugify(gateway.name)
140+
if gateway.slug != new_slug:
141+
gateway.slug = new_slug
142+
updated = True
143+
if updated:
144+
session.commit()
145+
146+
tools = session.query(Tool).all()
147+
for tool in tools:
148+
session.expire(tool, ["gateway"])
149+
150+
updated = False
151+
for tool in tools:
152+
if tool.gateway:
153+
new_name = f"{tool.gateway.slug}{settings.gateway_tool_name_separator}{slugify(tool.original_name)}"
154+
else:
155+
new_name = slugify(tool.original_name)
156+
if tool.name != new_name:
157+
tool.name = new_name
158+
updated = True
159+
if updated:
160+
session.commit()
161+
162+
131163
class Base(DeclarativeBase):
132164
"""Base class for all models."""
133165

134166

135-
# TODO: cleanup, not sure why this is commented out?
136-
# # Association table for tools and gateways (federation)
137-
# tool_gateway_table = Table(
138-
# "tool_gateway_association",
139-
# Base.metadata,
140-
# Column("tool_id", String, ForeignKey("tools.id"), primary_key=True),
141-
# Column("gateway_id", String, ForeignKey("gateways.id"), primary_key=True),
142-
# )
143-
144-
# # Association table for resources and gateways (federation)
145-
# resource_gateway_table = Table(
146-
# "resource_gateway_association",
147-
# Base.metadata,
148-
# Column("resource_id", Integer, ForeignKey("resources.id"), primary_key=True),
149-
# Column("gateway_id", String, ForeignKey("gateways.id"), primary_key=True),
150-
# )
151-
152-
# # Association table for prompts and gateways (federation)
153-
# prompt_gateway_table = Table(
154-
# "prompt_gateway_association",
155-
# Base.metadata,
156-
# Column("prompt_id", Integer, ForeignKey("prompts.id"), primary_key=True),
157-
# Column("gateway_id", String, ForeignKey("gateways.id"), primary_key=True),
158-
# )
159-
160167
# Association table for servers and tools
161168
server_tool_association = Table(
162169
"server_tool_association",
@@ -1168,8 +1175,10 @@ def validate_tool_name(mapper, connection, target):
11681175
_ = mapper
11691176
_ = connection
11701177
if hasattr(target, "name"):
1171-
if not re.match(r"^[a-zA-Z0-9_-]+$", target.name):
1172-
raise ValueError(f"Invalid tool name '{target.name}'. Only alphanumeric characters, hyphens, and underscores are allowed.")
1178+
try:
1179+
SecurityValidator.validate_tool_name(target.name)
1180+
except ValueError as e:
1181+
raise ValueError(f"Invalid tool name: {str(e)}")
11731182

11741183

11751184
def validate_prompt_schema(mapper, connection, target):
@@ -1238,3 +1247,34 @@ def init_db():
12381247
wait_for_db_ready(max_tries=int(settings.db_max_retries), interval=int(settings.db_retry_interval_ms) / 1000, sync=True) # Converting ms to s
12391248

12401249
init_db()
1250+
1251+
1252+
@event.listens_for(Gateway, "before_insert")
1253+
def set_gateway_slug(_mapper, _conn, target):
1254+
"""Set the slug for a Gateway before insert.
1255+
1256+
Args:
1257+
_mapper: Mapper
1258+
_conn: Connection
1259+
target: Target Gateway instance
1260+
"""
1261+
1262+
target.slug = slugify(target.name)
1263+
1264+
1265+
@event.listens_for(Tool, "before_insert")
1266+
def set_tool_name(_mapper, _conn, target):
1267+
"""Set the computed name for a Tool before insert.
1268+
1269+
Args:
1270+
_mapper: Mapper
1271+
_conn: Connection
1272+
target: Target Tool instance
1273+
"""
1274+
1275+
sep = settings.gateway_tool_name_separator
1276+
gateway_slug = target.gateway.slug if target.gateway_id else ""
1277+
if gateway_slug:
1278+
target.name = f"{gateway_slug}{sep}{slugify(target.original_name)}"
1279+
else:
1280+
target.name = slugify(target.original_name)

mcpgateway/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from mcpgateway.bootstrap_db import main as bootstrap_db
6161
from mcpgateway.cache import ResourceCache, SessionRegistry
6262
from mcpgateway.config import jsonpath_modifier, settings
63-
from mcpgateway.db import SessionLocal
63+
from mcpgateway.db import refresh_slugs_on_startup, SessionLocal
6464
from mcpgateway.handlers.sampling import SamplingHandler
6565
from mcpgateway.models import (
6666
InitializeRequest,
@@ -216,6 +216,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
216216
await sampling_handler.initialize()
217217
await resource_cache.initialize()
218218
await streamable_http_session.initialize()
219+
refresh_slugs_on_startup()
219220

220221
logger.info("All services initialized successfully")
221222
yield

mcpgateway/services/gateway_service.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import logging
2121
import os
2222
import tempfile
23-
from typing import Any, AsyncGenerator, Dict, List, Optional, Set
23+
from typing import Any, AsyncGenerator, Dict, List, Optional, Set, TYPE_CHECKING
2424
import uuid
2525

2626
# Third-Party
@@ -249,18 +249,26 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
249249
await self._notify_gateway_added(db_gateway)
250250

251251
return GatewayRead.model_validate(gateway)
252-
except GatewayConnectionError as ge:
253-
logger.error(f"GatewayConnectionError: {ge}")
254-
raise ge
255-
except ValueError as ve:
256-
logger.error(f"ValueError: {ve}")
257-
raise ve
258-
except RuntimeError as re:
259-
logger.error(f"RuntimeError: {re}")
260-
raise re
261-
except BaseException as other:
262-
logger.error(f"Other errors: {other}")
263-
raise other
252+
except* GatewayConnectionError as ge:
253+
if TYPE_CHECKING:
254+
ge: ExceptionGroup[GatewayConnectionError]
255+
logger.error(f"GatewayConnectionError in group: {ge.exceptions}")
256+
raise ge.exceptions[0]
257+
except* ValueError as ve:
258+
if TYPE_CHECKING:
259+
ve: ExceptionGroup[ValueError]
260+
logger.error(f"ValueErrors in group: {ve.exceptions}")
261+
raise ve.exceptions[0]
262+
except* RuntimeError as re:
263+
if TYPE_CHECKING:
264+
re: ExceptionGroup[RuntimeError]
265+
logger.error(f"RuntimeErrors in group: {re.exceptions}")
266+
raise re.exceptions[0]
267+
except* BaseException as other: # catches every other sub-exception
268+
if TYPE_CHECKING:
269+
other: ExceptionGroup[BaseException]
270+
logger.error(f"Other grouped errors: {other.exceptions}")
271+
raise other.exceptions[0]
264272

265273
async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
266274
"""List all registered gateways.

0 commit comments

Comments
 (0)