Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions .github/workflows/generate-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: |
pip install -e .

- name: Run generator script
- name: Run generator script (OpenAPI YAML)
run: |
mkdir -p tests/out/
python generator.py \
Expand All @@ -37,26 +37,40 @@ jobs:
--api-url http://localhost:8000/api \
--api-token "test-token"

- name: Run generator via CLI tool
- name: Run generator script (JSON specifications)
run: |
python generator.py \
tests/test_fixtures/ \
--output-dir ./tests/out/ \
--api-url http://localhost:8000/api \
--api-token "test-token"

- name: Run generator via CLI tool (OpenAPI YAML)
run: |
mkdir -p tests/out_cli/
mcp-generator tests/openapi.yaml --output-dir ./tests/out_cli/ --api-url http://localhost:8000/api --api-token "test-token"

- name: Run generator via Python module
- name: Run generator via CLI tool (JSON specifications)
run: |
mcp-generator tests/test_fixtures/ --output-dir ./tests/out_cli/ --api-url http://localhost:8000/api --api-token "test-token"

- name: Run generator via Python module (OpenAPI YAML)
run: |
mkdir -p tests/out_module/
python -m openapi_mcp_generator.cli tests/openapi.yaml --output-dir ./tests/out_module/ --api-url http://localhost:8000/api --api-token "test-token"

- name: Verify output directory exists
- name: Run generator via Python module (JSON specifications)
run: |
ls ./tests/out/ | grep "openapi-mcp-reference-test-api-"
shell: bash
python -m openapi_mcp_generator.cli tests/test_fixtures/ --output-dir ./tests/out_module/ --api-url http://localhost:8000/api --api-token "test-token"

- name: Verify generated mcp_server.py
- name: Verify output directories exist
run: |
GENERATED_DIR=$(ls ./tests/out/ | grep "openapi-mcp-reference-test-api-" | head -n 1)
grep -q "async def getItems(id: int, verbose: bool, limit: int, ctx: Context) -> str:" ./tests/out/$GENERATED_DIR/mcp_server.py
grep -q "def get_BadRequestDetails_schema" ./tests/out/$GENERATED_DIR/mcp_server.py
echo "Checking tests/out/ directory:"
ls ./tests/out/ | grep -E "(openapi-mcp-reference-test-api-|openapi-mcp-generated-api-)"
echo "Checking tests/out_cli/ directory:"
ls ./tests/out_cli/ | grep -E "(openapi-mcp-reference-test-api-|openapi-mcp-generated-api-)"
echo "Checking tests/out_module/ directory:"
ls ./tests/out_module/ | grep -E "(openapi-mcp-reference-test-api-|openapi-mcp-generated-api-)"
shell: bash

- name: Install pytest
Expand Down
21 changes: 17 additions & 4 deletions generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,33 @@

def parse_openapi_spec(filepath: str) -> Dict[str, Any]:
"""
Parse an OpenAPI specification file.
Parse an OpenAPI specification file or directory.

Args:
filepath: Path to the OpenAPI YAML file
filepath: Path to the OpenAPI YAML file or directory

Returns:
Dictionary containing the parsed OpenAPI specification

Raises:
SystemExit: If the file cannot be read or parsed
SystemExit: If the file/directory cannot be read or parsed
"""
# Try to use the modular parser first
if USE_MODULAR:
try:
from openapi_mcp_generator.parser import parse_openapi_spec as modular_parse
return modular_parse(filepath)
except ImportError:
pass

# Fallback to original implementation for single YAML files only
if not os.path.exists(filepath):
print(f"Error: OpenAPI specification file not found: {filepath}")
sys.exit(1)

if os.path.isdir(filepath):
print(f"Error: Directory processing requires the modular parser. Please install the package.")
sys.exit(1)

try:
with open(filepath, 'r', encoding='utf-8') as f:
Expand Down Expand Up @@ -101,7 +114,7 @@ def generate_tool_definitions(spec: Dict[str, Any]) -> str:
# Get parameters
parameters_definitions = []
for param_obj in operation.get('parameters', []):
actual_param = resolve_ref(param_obj, spec) if '$ref' in param_obj else param_obj
actual_param = resolve_ref(spec, param_obj['$ref']) if '$ref' in param_obj else param_obj
if not actual_param:
print(f"Warning: Could not resolve parameter reference: {param_obj}")
continue
Expand Down
4 changes: 3 additions & 1 deletion openapi_mcp_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
"""

from .generator import generate_mcp_server
from .parser import parse_openapi_spec, sanitize_description
from .parser import parse_openapi_spec, sanitize_description, sanitize_identifier, escape_string_literal
from .generators import generate_tool_definitions, generate_resource_definitions

__all__ = [
'generate_mcp_server',
'parse_openapi_spec',
'sanitize_description',
'sanitize_identifier',
'escape_string_literal',
'generate_tool_definitions',
'generate_resource_definitions',
]
147 changes: 118 additions & 29 deletions openapi_mcp_generator/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"""

import yaml
from typing import Dict, Any, List
from .parser import sanitize_description, resolve_ref
from typing import Dict, Any, List, Tuple
from .parser import sanitize_description, sanitize_identifier, escape_string_literal, resolve_ref


def generate_tool_definitions(spec: Dict[str, Any]) -> str:
Expand Down Expand Up @@ -49,14 +49,17 @@ def _generate_tool(spec: Dict[str, Any], path: str, method: str, operation: Dict
if 'operationId' not in operation:
return ""

operation_id = operation['operationId']
description = sanitize_description(operation.get('description', f"{method.upper()} {path}"))
operation_id = sanitize_identifier(operation['operationId'])
description = escape_string_literal(operation.get('description', f"{method.upper()} {path}"))

# Get parameters
parameters_definitions = _get_parameter_definitions(spec, operation)
# Get parameters separated by required vs optional
required_params, optional_params = _get_parameter_definitions(spec, operation)

# Add ctx parameter
parameters_definitions.append("ctx: Context")
# Combine parameters in correct order: required params, ctx, optional params
parameters_definitions = required_params + ["ctx: Context"] + optional_params

# Generate parameter processing code
param_processing = _generate_parameter_processing(spec, operation, path)

# Create tool function
return f"""
Expand All @@ -67,18 +70,13 @@ async def {operation_id}({', '.join(parameters_definitions)}) -> str:
\"\"\"
async with await get_http_client() as client:
try:
# Build the URL with path parameters
url = "{path}"

# Extract query parameters
query_params = {{}}
# ... build query params from function args
{param_processing}

# Make the request
response = await client.{method}(
url,
params=query_params,
# Add other parameters as needed
json=request_body if request_body else None
)

# Check if the request was successful
Expand All @@ -94,18 +92,21 @@ async def {operation_id}({', '.join(parameters_definitions)}) -> str:
"""


def _get_parameter_definitions(spec: Dict[str, Any], operation: Dict[str, Any]) -> List[str]:
def _get_parameter_definitions(spec: Dict[str, Any], operation: Dict[str, Any]) -> Tuple[List[str], List[str]]:
"""
Get parameter definitions for a tool function.
Get parameter definitions for a tool function, separated by required vs optional.

Args:
spec: The parsed OpenAPI specification
operation: The operation definition

Returns:
List of parameter definition strings
Tuple of (required_parameters, optional_parameters) definition strings
"""
parameters_definitions = []
required_params = []
optional_params = []
seen_params = set() # Track seen parameter names to avoid duplicates

for param_obj in operation.get('parameters', []):
actual_param = {}
if '$ref' in param_obj:
Expand All @@ -118,12 +119,34 @@ def _get_parameter_definitions(spec: Dict[str, Any], operation: Dict[str, Any])
print(f"Warning: Skipping parameter due to missing name or unresolved reference: {param_obj}")
continue

param_name = actual_param['name']
param_name = sanitize_identifier(actual_param['name'])

# Handle duplicate parameter names
original_param_name = param_name
counter = 1
while param_name in seen_params:
param_name = f"{original_param_name}_{counter}"
counter += 1

seen_params.add(param_name)
param_type = _get_param_type(actual_param)

parameters_definitions.append(f"{param_name}: {param_type}")
# Separate required and optional parameters
if actual_param.get('required', False):
required_params.append(f"{param_name}: {param_type}")
else:
# Add default value for optional parameters
if param_type == 'bool':
param_type = f"{param_type} = False"
elif param_type == 'str':
param_type = f"{param_type} = ''"
elif param_type in ['int', 'float']:
param_type = f"{param_type} = 0"
else:
param_type = f"Optional[{param_type}] = None"
optional_params.append(f"{param_name}: {param_type}")

return parameters_definitions
return required_params, optional_params


def _get_param_type(param: Dict[str, Any]) -> str:
Expand Down Expand Up @@ -151,6 +174,68 @@ def _get_param_type(param: Dict[str, Any]) -> str:
return param_type


def _generate_parameter_processing(spec: Dict[str, Any], operation: Dict[str, Any], path: str) -> str:
"""
Generate parameter processing code for a tool function.

Args:
spec: The parsed OpenAPI specification
operation: The operation definition
path: The API path

Returns:
String containing parameter processing code
"""
lines = []
lines.append(" # Build the URL with path parameters")
lines.append(f" url = \"{path}\"")
lines.append("")
lines.append(" # Extract query parameters")
lines.append(" query_params = {}")
lines.append(" request_body = None")
lines.append("")

# Process parameters
seen_params = set()
for param_obj in operation.get('parameters', []):
actual_param = {}
if '$ref' in param_obj:
ref_path = param_obj['$ref']
actual_param = resolve_ref(spec, ref_path)
else:
actual_param = param_obj

if not actual_param or 'name' not in actual_param:
continue

param_name = sanitize_identifier(actual_param['name'])
original_param_name = param_name

# Handle duplicate parameter names
counter = 1
while param_name in seen_params:
param_name = f"{original_param_name}_{counter}"
counter += 1
seen_params.add(param_name)

param_in = actual_param.get('in', 'query')
original_name = actual_param['name']

if param_in == 'path':
# Replace path parameters in URL
lines.append(f" if {param_name} is not None:")
lines.append(f" url = url.replace('{{{original_name}}}', str({param_name}))")
elif param_in == 'query':
# Add to query parameters
lines.append(f" if {param_name} is not None:")
lines.append(f" query_params['{original_name}'] = {param_name}")
elif param_in == 'header':
# We'll handle headers separately if needed
pass

return "\n".join(lines)


def generate_resource_definitions(spec: Dict[str, Any]) -> str:
"""
Generate MCP resource definitions from OpenAPI components.
Expand Down Expand Up @@ -185,9 +270,9 @@ def _generate_api_info_resource(spec: Dict[str, Any]) -> str:
String containing the generated resource definition
"""
info = spec.get('info', {})
api_title = info.get('title', 'API')
api_version = info.get('version', '1.0.0')
api_description = sanitize_description(info.get('description', 'API description'))
api_title = escape_string_literal(info.get('title', 'API'))
api_version = escape_string_literal(info.get('version', '1.0.0'))
api_description = escape_string_literal(info.get('description', 'API description'))

return f"""
@mcp.resource("api://info")
Expand Down Expand Up @@ -216,14 +301,18 @@ def _generate_schema_resources(spec: Dict[str, Any]) -> List[str]:
schema_resources = []

for schema_name, schema in spec.get('components', {}).get('schemas', {}).items():
safe_schema_name = sanitize_identifier(schema_name)
escaped_schema_name = escape_string_literal(schema_name)
schema_yaml = escape_string_literal(yaml.dump(schema, default_flow_style=False))

resource_def = f"""
@mcp.resource("schema://{schema_name}")
def get_{schema_name}_schema() -> str:
@mcp.resource("schema://{escaped_schema_name}")
def get_{safe_schema_name}_schema() -> str:
\"\"\"
Get the {schema_name} schema definition
Get the {escaped_schema_name} schema definition
\"\"\"
return \"\"\"
{yaml.dump(schema, default_flow_style=False)}
{schema_yaml}
\"\"\"
"""
schema_resources.append(resource_def)
Expand Down
Loading