diff --git a/samcli/cli/command.py b/samcli/cli/command.py index a5ba816847..387b931e09 100644 --- a/samcli/cli/command.py +++ b/samcli/cli/command.py @@ -22,6 +22,7 @@ "samcli.commands.validate.validate", "samcli.commands.build", "samcli.commands.local.local", + "samcli.commands.generate.generate", "samcli.commands.package", "samcli.commands.deploy", "samcli.commands.delete", @@ -170,6 +171,10 @@ def format_commands(self, ctx: click.Context, formatter: RootCommandHelpTextForm name="validate", text=SAM_CLI_COMMANDS.get("validate", ""), ), + RowDefinition( + name="generate", + text=SAM_CLI_COMMANDS.get("generate", ""), + ), RowDefinition( name="sync", text=SAM_CLI_COMMANDS.get("sync", ""), diff --git a/samcli/cli/root/command_list.py b/samcli/cli/root/command_list.py index dc5316b5ca..335332973c 100644 --- a/samcli/cli/root/command_list.py +++ b/samcli/cli/root/command_list.py @@ -7,6 +7,7 @@ "validate": "Validate an AWS SAM template.", "build": "Build your AWS serverless function code.", "local": "Run your AWS serverless function locally.", + "generate": "Generate artifacts from SAM templates.", "remote": "Invoke or send an event to cloud resources in your AWS Cloudformation stack.", "package": "Package an AWS SAM application.", "deploy": "Deploy an AWS SAM application.", diff --git a/samcli/commands/generate/__init__.py b/samcli/commands/generate/__init__.py new file mode 100644 index 0000000000..aca11a33e9 --- /dev/null +++ b/samcli/commands/generate/__init__.py @@ -0,0 +1,3 @@ +""" +Generate Command Group +""" diff --git a/samcli/commands/generate/generate.py b/samcli/commands/generate/generate.py new file mode 100644 index 0000000000..39ce920dcf --- /dev/null +++ b/samcli/commands/generate/generate.py @@ -0,0 +1,29 @@ +""" +Main CLI group for 'generate' commands +""" + +import click + +from samcli.cli.main import common_options, pass_context, print_cmdline_args +from samcli.commands._utils.command_exception_handler import command_exception_handler +from samcli.commands.generate.openapi.command import cli as openapi_cli + + +@click.group() +@common_options +@pass_context +@print_cmdline_args +@command_exception_handler +def cli(ctx): + """ + Generate artifacts from SAM templates. + + This command group provides subcommands to generate various artifacts + from your SAM templates, such as OpenAPI specifications, CloudFormation + templates, and more. + """ + pass + + +# Add openapi subcommand +cli.add_command(openapi_cli) diff --git a/samcli/commands/generate/openapi/__init__.py b/samcli/commands/generate/openapi/__init__.py new file mode 100644 index 0000000000..477a6c883e --- /dev/null +++ b/samcli/commands/generate/openapi/__init__.py @@ -0,0 +1,3 @@ +""" +OpenAPI Generation Command +""" diff --git a/samcli/commands/generate/openapi/command.py b/samcli/commands/generate/openapi/command.py new file mode 100644 index 0000000000..ff68403a23 --- /dev/null +++ b/samcli/commands/generate/openapi/command.py @@ -0,0 +1,124 @@ +""" +CLI command for "generate openapi" command +""" + +import click + +from samcli.cli.cli_config_file import ConfigProvider, configuration_option +from samcli.cli.main import aws_creds_options, common_options, pass_context, print_cmdline_args +from samcli.commands._utils.command_exception_handler import command_exception_handler +from samcli.commands._utils.options import parameter_override_option, template_click_option +from samcli.lib.telemetry.metric import track_command +from samcli.lib.utils.version_checker import check_newer_version + +SHORT_HELP = "Generate OpenAPI specification from SAM template." + +DESCRIPTION = """ + Generate an OpenAPI (Swagger) specification document from a SAM template. + + SAM automatically generates OpenAPI documents for your APIs at deploy time. + This command allows you to access that generated OpenAPI document as part of + your build process, enabling integration with tools like swagger-codegen, + OpenAPI Generator, and other API documentation/client generation tools. +""" + + +@click.command( + "openapi", + short_help=SHORT_HELP, + help=DESCRIPTION, + context_settings={"max_content_width": 120}, +) +@configuration_option(provider=ConfigProvider(section="parameters")) +@template_click_option(include_build=False) +@click.option( + "--api-logical-id", + required=False, + type=str, + help="Logical ID of the API resource to generate OpenAPI for. " + "Required when template contains multiple APIs. " + "Defaults to auto-detection when only one API exists.", +) +@click.option( + "--output-file", + "-o", + required=False, + type=click.Path(), + help="Path to output file for generated OpenAPI document. " "If not specified, outputs to stdout.", +) +@click.option( + "--format", + type=click.Choice(["yaml", "json"], case_sensitive=False), + default="yaml", + help="Output format for the OpenAPI document.", + show_default=True, +) +@click.option( + "--openapi-version", + type=click.Choice(["2.0", "3.0"], case_sensitive=False), + default="3.0", + help="OpenAPI specification version (2.0 = Swagger, 3.0 = OpenAPI).", + show_default=True, +) +@parameter_override_option +@common_options +@aws_creds_options +@pass_context +@track_command +@check_newer_version +@print_cmdline_args +@command_exception_handler +def cli( + ctx, + template_file, + api_logical_id, + output_file, + format, + openapi_version, + parameter_overrides, + config_file, + config_env, +): + """ + `sam generate openapi` command entry point + """ + # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing + + do_cli( + template_file=template_file, + api_logical_id=api_logical_id, + output_file=output_file, + output_format=format, + openapi_version=openapi_version, + parameter_overrides=parameter_overrides, + region=ctx.region, + profile=ctx.profile, + ) # pragma: no cover + + +def do_cli( + template_file, + api_logical_id, + output_file, + output_format, + openapi_version, + parameter_overrides, + region, + profile, +): + """ + Implementation of the ``cli`` method + """ + from samcli.commands.generate.openapi.context import OpenApiContext + + with OpenApiContext( + template_file=template_file, + api_logical_id=api_logical_id, + output_file=output_file, + output_format=output_format, + openapi_version=openapi_version, + parameter_overrides=parameter_overrides, + region=region, + profile=profile, + ) as context: + context.run() diff --git a/samcli/commands/generate/openapi/context.py b/samcli/commands/generate/openapi/context.py new file mode 100644 index 0000000000..0ee2a032d2 --- /dev/null +++ b/samcli/commands/generate/openapi/context.py @@ -0,0 +1,178 @@ +""" +Context for OpenAPI generation command execution +""" + +import json +import logging +from typing import Dict, Optional, cast + +import click + +from samcli.commands.generate.openapi.exceptions import GenerateOpenApiException +from samcli.lib.generate.openapi_generator import OpenApiGenerator +from samcli.yamlhelper import yaml_dump + +LOG = logging.getLogger(__name__) + + +class OpenApiContext: + """ + Context manager for OpenAPI generation command + """ + + MSG_SUCCESS = "\nSuccessfully generated OpenAPI document{output_info}.\n" + + def __init__( + self, + template_file: str, + api_logical_id: Optional[str], + output_file: Optional[str], + output_format: str, + openapi_version: str, + parameter_overrides: Optional[Dict], + region: Optional[str], + profile: Optional[str], + ): + """ + Initialize OpenAPI generation context + + Parameters + ---------- + template_file : str + Path to SAM template + api_logical_id : str, optional + API resource logical ID + output_file : str, optional + Output file path (None for stdout) + output_format : str + Output format: 'yaml' or 'json' + openapi_version : str + OpenAPI version: '2.0' or '3.0' + parameter_overrides : dict, optional + Template parameter overrides + region : str, optional + AWS region + profile : str, optional + AWS profile + """ + self.template_file = template_file + self.api_logical_id = api_logical_id + self.output_file = output_file + self.output_format = output_format + self.openapi_version = openapi_version + self.parameter_overrides = parameter_overrides + self.region = region + self.profile = profile + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def run(self): + """ + Execute OpenAPI generation + + Raises + ------ + GenerateOpenApiException + If generation fails + """ + try: + LOG.debug( + "Generating OpenAPI - Template: %s, API ID: %s, Format: %s", + self.template_file, + self.api_logical_id or "auto-detect", + self.output_format, + ) + + # Create generator + generator = OpenApiGenerator( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + parameter_overrides=self.parameter_overrides, + region=self.region, + profile=self.profile, + ) + + # Generate OpenAPI document + openapi_doc = generator.generate() + + # Convert to OpenAPI 3.0 if requested + if self.openapi_version == "3.0": + from samcli.lib.generate.openapi_converter import OpenApiConverter + + openapi_doc = OpenApiConverter.swagger_to_openapi3(openapi_doc) + + # Format output + output_str = self._format_output(openapi_doc) + + # Write output + self._write_output(output_str) + + # Display success message + self._display_success() + + except GenerateOpenApiException: + # Re-raise our specific exceptions + raise + except Exception as e: + # Wrap unexpected exceptions + raise GenerateOpenApiException(f"Unexpected error during OpenAPI generation: {str(e)}") from e + + def _format_output(self, openapi_doc: Dict) -> str: + """ + Format OpenAPI document as YAML or JSON + + Parameters + ---------- + openapi_doc : dict + OpenAPI document + + Returns + ------- + str + Formatted output string + """ + if self.output_format == "json": + return json.dumps(openapi_doc, indent=2, ensure_ascii=False) + else: + # Default to YAML + return cast(str, yaml_dump(openapi_doc)) + + def _write_output(self, content: str): + """ + Write output to file or stdout + + Parameters + ---------- + content : str + Content to write + """ + if self.output_file: + # Write to file + try: + with open(self.output_file, "w") as f: + f.write(content) + LOG.debug("Wrote OpenAPI document to file: %s", self.output_file) + except IOError as e: + raise GenerateOpenApiException(f"Failed to write to file '{self.output_file}': {str(e)}") from e + else: + # Write to stdout + click.echo(content) + + def _display_success(self): + """ + Display success message to user + """ + if self.output_file: + output_info = f" and wrote to file: {self.output_file}" + else: + output_info = "" + + msg = self.MSG_SUCCESS.format(output_info=output_info) + if self.output_file: + # Only show success message if writing to file + # (to not clutter stdout when piping) + click.secho(msg, fg="green") diff --git a/samcli/commands/generate/openapi/exceptions.py b/samcli/commands/generate/openapi/exceptions.py new file mode 100644 index 0000000000..d092d5978f --- /dev/null +++ b/samcli/commands/generate/openapi/exceptions.py @@ -0,0 +1,79 @@ +""" +Exceptions for generate openapi command +""" + +from samcli.commands.exceptions import UserException + + +class GenerateOpenApiException(UserException): + """Base exception for OpenAPI generation""" + + +class ApiResourceNotFoundException(GenerateOpenApiException): + """Raised when specified API resource not found""" + + fmt = "API resource '{api_id}' not found in template. {message}" + + def __init__(self, api_id, message=""): + self.api_id = api_id + self.message = message + msg = self.fmt.format(api_id=api_id, message=message) + super().__init__(message=msg) + + +class InvalidApiResourceException(GenerateOpenApiException): + """Raised when API resource is invalid""" + + fmt = "API resource '{api_id}' is not valid. {message}" + + def __init__(self, api_id, message=""): + self.api_id = api_id + self.message = message + msg = self.fmt.format(api_id=api_id, message=message) + super().__init__(message=msg) + + +class OpenApiExtractionException(GenerateOpenApiException): + """Raised when OpenAPI extraction fails""" + + fmt = "Failed to extract OpenAPI definition: {message}" + + def __init__(self, message=""): + self.message = message + msg = self.fmt.format(message=message) + super().__init__(message=msg) + + +class TemplateTransformationException(GenerateOpenApiException): + """Raised when SAM transformation fails""" + + fmt = "Failed to transform SAM template: {message}" + + def __init__(self, message=""): + self.message = message + msg = self.fmt.format(message=message) + super().__init__(message=msg) + + +class NoApiResourcesFoundException(GenerateOpenApiException): + """Raised when no API resources found in template""" + + fmt = "No API resources found in template. {message}" + + def __init__( + self, message="Please ensure your template contains AWS::Serverless::Api or AWS::Serverless::HttpApi resources." + ): + self.message = message + msg = self.fmt.format(message=message) + super().__init__(message=msg) + + +class MultipleApiResourcesException(GenerateOpenApiException): + """Raised when multiple API resources found and no logical ID specified""" + + fmt = "Multiple API resources found: {api_ids}. Please specify --api-logical-id." + + def __init__(self, api_ids): + self.api_ids = api_ids + msg = self.fmt.format(api_ids=", ".join(api_ids)) + super().__init__(message=msg) diff --git a/samcli/lib/generate/__init__.py b/samcli/lib/generate/__init__.py new file mode 100644 index 0000000000..2c9fdf1174 --- /dev/null +++ b/samcli/lib/generate/__init__.py @@ -0,0 +1,3 @@ +""" +Library for generate commands +""" diff --git a/samcli/lib/generate/openapi_converter.py b/samcli/lib/generate/openapi_converter.py new file mode 100644 index 0000000000..8005786c22 --- /dev/null +++ b/samcli/lib/generate/openapi_converter.py @@ -0,0 +1,55 @@ +""" +Convert Swagger 2.0 to OpenAPI 3.0 specification +""" + +import copy +from typing import Dict + + +class OpenApiConverter: + """Converts Swagger 2.0 specs to OpenAPI 3.0""" + + @staticmethod + def swagger_to_openapi3(swagger_doc: Dict) -> Dict: + """ + Convert Swagger 2.0 document to OpenAPI 3.0 + + Parameters + ---------- + swagger_doc : dict + Swagger 2.0 document + + Returns + ------- + dict + OpenAPI 3.0 document + """ + if not swagger_doc or not isinstance(swagger_doc, dict): + return swagger_doc + + # Check if already OpenAPI 3.0 + if "openapi" in swagger_doc: + return swagger_doc + + # Check if Swagger 2.0 + if "swagger" not in swagger_doc: + return swagger_doc + + # Create OpenAPI 3.0 document + openapi_doc = copy.deepcopy(swagger_doc) + + # 1. Change version + openapi_doc["openapi"] = "3.0.0" + del openapi_doc["swagger"] + + # 2. Move securityDefinitions to components.securitySchemes + if "securityDefinitions" in openapi_doc: + if "components" not in openapi_doc: + openapi_doc["components"] = {} + openapi_doc["components"]["securitySchemes"] = openapi_doc["securityDefinitions"] + del openapi_doc["securityDefinitions"] + + # 3. Keep x-amazon-apigateway extensions as is (API Gateway specific) + # These are AWS extensions that work in both formats + + return openapi_doc diff --git a/samcli/lib/generate/openapi_generator.py b/samcli/lib/generate/openapi_generator.py new file mode 100644 index 0000000000..941ad6552d --- /dev/null +++ b/samcli/lib/generate/openapi_generator.py @@ -0,0 +1,504 @@ +""" +OpenAPI Generator - Extracts OpenAPI specification from SAM templates +""" + +import logging +from typing import Dict, List, Optional, Tuple, cast + +from samcli.commands.generate.openapi.exceptions import ( + ApiResourceNotFoundException, + MultipleApiResourcesException, + NoApiResourcesFoundException, + OpenApiExtractionException, + TemplateTransformationException, +) +from samcli.commands.local.lib.swagger.reader import SwaggerReader +from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.lib.translate.sam_template_validator import SamTemplateValidator +from samcli.yamlhelper import yaml_parse + +LOG = logging.getLogger(__name__) + +# API resource types +SERVERLESS_API = "AWS::Serverless::Api" +SERVERLESS_HTTP_API = "AWS::Serverless::HttpApi" +API_GATEWAY_REST_API = "AWS::ApiGateway::RestApi" +API_GATEWAY_V2_API = "AWS::ApiGatewayV2::Api" + +SUPPORTED_API_TYPES = [SERVERLESS_API, SERVERLESS_HTTP_API, API_GATEWAY_REST_API, API_GATEWAY_V2_API] + + +class OpenApiGenerator: + """ + Generates OpenAPI specification from SAM template + """ + + def __init__( + self, + template_file: str, + api_logical_id: Optional[str] = None, + parameter_overrides: Optional[Dict] = None, + region: Optional[str] = None, + profile: Optional[str] = None, + ): + """ + Initialize OpenAPI generator + + Parameters + ---------- + template_file : str + Path to SAM template file + api_logical_id : str, optional + Logical ID of API resource to generate OpenAPI for + parameter_overrides : dict, optional + Template parameter overrides + region : str, optional + AWS region (for intrinsic functions) + profile : str, optional + AWS profile + """ + self.template_file = template_file + self.api_logical_id = api_logical_id + self.parameter_overrides = parameter_overrides or {} + self.region = region + self.profile = profile + + def generate(self) -> Dict: + """ + Main generation method - extracts OpenAPI from SAM template + + Returns + ------- + dict + OpenAPI document as dictionary + + Raises + ------ + NoApiResourcesFoundException + If no API resources found in template + ApiResourceNotFoundException + If specified API logical ID not found + MultipleApiResourcesException + If multiple APIs found and no logical ID specified + OpenApiExtractionException + If OpenAPI extraction fails + """ + LOG.debug("Starting OpenAPI generation from template: %s", self.template_file) + + # 1. Load and parse template + template = self._load_template() + + # 2. Find API resources + api_resources = self._find_api_resources(template) + + # 3. Check if template has implicit API (functions with API events) + has_implicit_api = self._has_implicit_api(template) + + if not api_resources and not has_implicit_api: + raise NoApiResourcesFoundException() + + # 4. If explicit API exists, use it + if api_resources: + target_api_id, target_api_resource = self._select_api_resource(api_resources) + LOG.debug("Generating OpenAPI for resource: %s (Type: %s)", target_api_id, target_api_resource.get("Type")) + + # Try to extract existing OpenAPI definition first + openapi_doc = self._extract_existing_definition(target_api_resource, target_api_id) + + if not openapi_doc: + LOG.debug("No existing OpenAPI definition found, transforming template") + openapi_doc = self._generate_from_transformation(template, target_api_id, target_api_resource) + else: + # Handle implicit API + LOG.debug("No explicit API resource found, checking for implicit ServerlessRestApi") + target_api_id = self.api_logical_id or "ServerlessRestApi" + openapi_doc = self._generate_implicit_api(template, target_api_id) + + # 5. Validate OpenAPI structure + if not self._validate_openapi(openapi_doc): + raise OpenApiExtractionException("Generated OpenAPI document is invalid or empty") + + LOG.debug("Successfully generated OpenAPI document") + return openapi_doc + + def _load_template(self) -> Dict: + """ + Load and parse SAM template + + Returns + ------- + dict + Parsed template dictionary + + Raises + ------ + OpenApiExtractionException + If template cannot be loaded + """ + try: + with open(self.template_file, "r") as f: + template = yaml_parse(f.read()) + + if not template or not isinstance(template, dict): + raise OpenApiExtractionException("Template file is empty or invalid") + + if "Resources" not in template: + raise OpenApiExtractionException("Template does not contain 'Resources' section") + + return template + + except FileNotFoundError as e: + raise OpenApiExtractionException(f"Template file not found: {self.template_file}") from e + except Exception as e: + raise OpenApiExtractionException(f"Failed to load template: {str(e)}") from e + + def _find_api_resources(self, template: Dict) -> Dict[str, Dict]: + """ + Find all API resources in template + + Parameters + ---------- + template : dict + SAM template dictionary + + Returns + ------- + dict + Dictionary of API resources {logical_id: resource_dict} + """ + api_resources = {} + resources = template.get("Resources", {}) + + for logical_id, resource in resources.items(): + resource_type = resource.get("Type") + if resource_type in SUPPORTED_API_TYPES: + api_resources[logical_id] = resource + + LOG.debug("Found %d API resources: %s", len(api_resources), list(api_resources.keys())) + return api_resources + + def _select_api_resource(self, api_resources: Dict[str, Dict]) -> Tuple[str, Dict]: + """ + Select which API resource to generate OpenAPI for + + Parameters + ---------- + api_resources : dict + Dictionary of API resources + + Returns + ------- + tuple + (logical_id, resource_dict) + + Raises + ------ + ApiResourceNotFoundException + If specified API not found + MultipleApiResourcesException + If multiple APIs found and none specified + """ + if self.api_logical_id: + # User specified an API logical ID + if self.api_logical_id not in api_resources: + available_apis = ", ".join(api_resources.keys()) + raise ApiResourceNotFoundException( + self.api_logical_id, f"Available APIs: {available_apis}" if available_apis else "No APIs found" + ) + return self.api_logical_id, api_resources[self.api_logical_id] + + # No API specified - check if there's only one + if len(api_resources) == 1: + logical_id = list(api_resources.keys())[0] + return logical_id, api_resources[logical_id] + + # Multiple APIs and none specified + raise MultipleApiResourcesException(list(api_resources.keys())) + + def _extract_existing_definition(self, resource: Dict, logical_id: str) -> Optional[Dict]: + """ + Extract OpenAPI if already defined in DefinitionBody or DefinitionUri + + Parameters + ---------- + resource : dict + API resource dictionary + logical_id : str + Logical ID of the resource + + Returns + ------- + dict or None + OpenAPI document if found, None otherwise + """ + properties = resource.get("Properties", {}) + definition_body = properties.get("DefinitionBody") + definition_uri = properties.get("DefinitionUri") + + # For ApiGateway resources, check Body and BodyS3Location + if not definition_body: + definition_body = properties.get("Body") + if not definition_uri: + definition_uri = properties.get("BodyS3Location") + + if not definition_body and not definition_uri: + LOG.debug("No DefinitionBody or DefinitionUri found in resource %s", logical_id) + return None + + try: + # Use SwaggerReader to handle various definition sources + reader = SwaggerReader(definition_body=definition_body, definition_uri=definition_uri, working_dir=".") + openapi_doc = reader.read() + + if openapi_doc: + LOG.debug("Successfully extracted existing OpenAPI definition from resource %s", logical_id) + return cast(Dict, openapi_doc) + + except Exception as e: + LOG.debug("Failed to read existing definition: %s", str(e)) + + return None + + def _generate_from_transformation(self, template: Dict, logical_id: str, resource: Dict) -> Dict: + """ + Generate OpenAPI by transforming SAM template to CloudFormation + + Parameters + ---------- + template : dict + SAM template + logical_id : str + API logical ID + resource : dict + API resource + + Returns + ------- + dict + Generated OpenAPI document + + Raises + ------ + TemplateTransformationException + If transformation fails + OpenApiExtractionException + If OpenAPI extraction from transformed template fails + """ + try: + # Transform template using SAM Translator + validator = SamTemplateValidator( + sam_template=template, + managed_policy_loader=None, + profile=self.profile, + region=self.region, + parameter_overrides=self.parameter_overrides, + ) + + # Get transformed CloudFormation template + transformed_str = validator.get_translated_template_if_valid() + + # Parse transformed template + transformed_template = yaml_parse(transformed_str) + + # Extract OpenAPI from transformed resource + openapi_doc = self._extract_from_cfn_template(transformed_template, logical_id, resource) + + return openapi_doc + + except InvalidSamDocumentException as e: + raise TemplateTransformationException(str(e)) from e + except Exception as e: + raise TemplateTransformationException(f"Unexpected error during transformation: {str(e)}") from e + + def _extract_from_cfn_template(self, cfn_template: Dict, original_logical_id: str, original_resource: Dict) -> Dict: + """ + Extract OpenAPI definition from transformed CloudFormation template + + Parameters + ---------- + cfn_template : dict + Transformed CloudFormation template + original_logical_id : str + Original SAM resource logical ID + original_resource : dict + Original SAM resource + + Returns + ------- + dict + OpenAPI document + + Raises + ------ + OpenApiExtractionException + If OpenAPI cannot be extracted + """ + resources = cfn_template.get("Resources", {}) + + # The transformed resource might have the same or different logical ID + # For ServerlessRestApi created implicitly, SAM generates it with that name + possible_ids = [original_logical_id, "ServerlessRestApi", "ServerlessHttpApi"] + + for resource_id in possible_ids: + if resource_id in resources: + resource = resources[resource_id] + resource_type = resource.get("Type") + + # Check if it's an API Gateway resource + if resource_type in [API_GATEWAY_REST_API, API_GATEWAY_V2_API]: + properties = resource.get("Properties", {}) + definition_body = properties.get("Body") or properties.get("DefinitionBody") + + if definition_body and isinstance(definition_body, dict): + LOG.debug("Extracted OpenAPI from transformed resource %s", resource_id) + return cast(Dict, definition_body) + + # If we couldn't find it in transformed resources, try the original approach + raise OpenApiExtractionException( + f"Could not extract OpenAPI definition from transformed template for resource '{original_logical_id}'. " + "The resource may not generate an OpenAPI document or transformation failed." + ) + + def _validate_openapi(self, openapi_doc: Dict) -> bool: + """ + Validate OpenAPI document structure + + Parameters + ---------- + openapi_doc : dict + OpenAPI document + + Returns + ------- + bool + True if valid, False otherwise + """ + if not openapi_doc or not isinstance(openapi_doc, dict): + return False + + # Check for required OpenAPI fields + has_swagger = "swagger" in openapi_doc + has_openapi = "openapi" in openapi_doc + has_paths = "paths" in openapi_doc + + if not (has_swagger or has_openapi): + LOG.warning("OpenAPI document missing 'swagger' or 'openapi' version field") + return False + + if not has_paths: + LOG.warning("OpenAPI document missing 'paths' field") + return False + + return True + + def _has_implicit_api(self, template: Dict) -> bool: + """ + Check if template has implicit API (functions with API events) + + Parameters + ---------- + template : dict + SAM template + + Returns + ------- + bool + True if template has functions with API events + """ + resources = template.get("Resources", {}) + + for resource in resources.values(): + if resource.get("Type") == "AWS::Serverless::Function": + events = resource.get("Properties", {}).get("Events", {}) + for event in events.values(): + event_type = event.get("Type", "") + if event_type in ["Api", "HttpApi"]: + return True + + return False + + def _generate_implicit_api(self, template: Dict, api_id: str) -> Dict: + """ + Generate OpenAPI for implicit API by transforming template + + Parameters + ---------- + template : dict + SAM template + api_id : str + API logical ID (e.g., ServerlessRestApi) + + Returns + ------- + dict + Generated OpenAPI document + + Raises + ------ + TemplateTransformationException + If transformation fails + OpenApiExtractionException + If OpenAPI extraction fails + """ + try: + # Transform template using SAM Translator + validator = SamTemplateValidator( + sam_template=template, + managed_policy_loader=None, + profile=self.profile, + region=self.region, + parameter_overrides=self.parameter_overrides, + ) + + # Get transformed CloudFormation template + transformed_str = validator.get_translated_template_if_valid() + + # Parse transformed template + transformed_template = yaml_parse(transformed_str) + + # Extract OpenAPI from transformed implicit API + resources = transformed_template.get("Resources", {}) + + if api_id in resources: + resource = resources[api_id] + resource_type = resource.get("Type") + + if resource_type in [API_GATEWAY_REST_API, API_GATEWAY_V2_API]: + properties = resource.get("Properties", {}) + definition_body = properties.get("Body") or properties.get("DefinitionBody") + + if definition_body and isinstance(definition_body, dict): + LOG.debug("Extracted OpenAPI from implicit API %s", api_id) + return cast(Dict, definition_body) + + raise OpenApiExtractionException( + f"Could not extract OpenAPI definition for implicit API '{api_id}'. " + "The template may not generate an implicit API or transformation failed." + ) + + except InvalidSamDocumentException as e: + raise TemplateTransformationException(str(e)) from e + except OpenApiExtractionException: + raise + except Exception as e: + raise TemplateTransformationException(f"Unexpected error during transformation: {str(e)}") from e + + def get_api_resources_info(self) -> List[Dict[str, str]]: + """ + Get information about API resources in the template (useful for CLI output) + + Returns + ------- + list + List of dicts with API resource information + """ + try: + template = self._load_template() + api_resources = self._find_api_resources(template) + + return [ + {"LogicalId": logical_id, "Type": resource.get("Type", "Unknown")} + for logical_id, resource in api_resources.items() + ] + except Exception: + return [] diff --git a/schema/samcli.json b/schema/samcli.json index 5203d60bc7..32616d06aa 100644 --- a/schema/samcli.json +++ b/schema/samcli.json @@ -1025,6 +1025,89 @@ "parameters" ] }, + "generate_openapi": { + "title": "Generate Openapi command", + "description": "Generate an OpenAPI (Swagger) specification document from a SAM template.\n \n SAM automatically generates OpenAPI documents for your APIs at deploy time. \n This command allows you to access that generated OpenAPI document as part of \n your build process, enabling integration with tools like swagger-codegen, \n OpenAPI Generator, and other API documentation/client generation tools.", + "properties": { + "parameters": { + "title": "Parameters for the generate openapi command", + "description": "Available parameters for the generate openapi command:\n* template_file:\nAWS SAM template file.\n* api_logical_id:\nLogical ID of the API resource to generate OpenAPI for. Required when template contains multiple APIs. Defaults to auto-detection when only one API exists.\n* output_file:\nPath to output file for generated OpenAPI document. If not specified, outputs to stdout.\n* format:\nOutput format for the OpenAPI document.\n* openapi_version:\nOpenAPI specification version (2.0 = Swagger, 3.0 = OpenAPI).\n* parameter_overrides:\nString that contains AWS CloudFormation parameter overrides encoded as key=value pairs.\n* beta_features:\nEnable/Disable beta features.\n* debug:\nTurn on debug logging to print debug message generated by AWS SAM CLI and display timestamps.\n* profile:\nSelect a specific profile from your credential file to get AWS credentials.\n* region:\nSet the AWS Region of the service. (e.g. us-east-1)", + "type": "object", + "properties": { + "template_file": { + "title": "template_file", + "type": "string", + "description": "AWS SAM template file.", + "default": "template.[yaml|yml|json]" + }, + "api_logical_id": { + "title": "api_logical_id", + "type": "string", + "description": "Logical ID of the API resource to generate OpenAPI for. Required when template contains multiple APIs. Defaults to auto-detection when only one API exists." + }, + "output_file": { + "title": "output_file", + "type": "string", + "description": "Path to output file for generated OpenAPI document. If not specified, outputs to stdout." + }, + "format": { + "title": "format", + "type": "string", + "description": "Output format for the OpenAPI document.", + "default": "yaml", + "enum": [ + "json", + "yaml" + ] + }, + "openapi_version": { + "title": "openapi_version", + "type": "string", + "description": "OpenAPI specification version (2.0 = Swagger, 3.0 = OpenAPI).", + "default": "3.0", + "enum": [ + "2.0", + "3.0" + ] + }, + "parameter_overrides": { + "title": "parameter_overrides", + "type": [ + "array", + "string" + ], + "description": "String that contains AWS CloudFormation parameter overrides encoded as key=value pairs.", + "items": { + "type": "string" + } + }, + "beta_features": { + "title": "beta_features", + "type": "boolean", + "description": "Enable/Disable beta features." + }, + "debug": { + "title": "debug", + "type": "boolean", + "description": "Turn on debug logging to print debug message generated by AWS SAM CLI and display timestamps." + }, + "profile": { + "title": "profile", + "type": "string", + "description": "Select a specific profile from your credential file to get AWS credentials." + }, + "region": { + "title": "region", + "type": "string", + "description": "Set the AWS Region of the service. (e.g. us-east-1)" + } + } + } + }, + "required": [ + "parameters" + ] + }, "package": { "title": "Package command", "description": "Package an AWS SAM application.", diff --git a/tests/integration/generate/__init__.py b/tests/integration/generate/__init__.py new file mode 100644 index 0000000000..d46bb6404b --- /dev/null +++ b/tests/integration/generate/__init__.py @@ -0,0 +1,3 @@ +""" +Integration tests for generate commands +""" diff --git a/tests/integration/generate/openapi/__init__.py b/tests/integration/generate/openapi/__init__.py new file mode 100644 index 0000000000..c44abd6a4e --- /dev/null +++ b/tests/integration/generate/openapi/__init__.py @@ -0,0 +1,3 @@ +""" +Integration tests for openapi command +""" diff --git a/tests/integration/generate/openapi/generate_openapi_integ_base.py b/tests/integration/generate/openapi/generate_openapi_integ_base.py new file mode 100644 index 0000000000..2fc4cd5ac6 --- /dev/null +++ b/tests/integration/generate/openapi/generate_openapi_integ_base.py @@ -0,0 +1,51 @@ +""" +Base class for generate openapi integration tests +""" + +import os +import uuid +import shutil +import tempfile +from pathlib import Path +from unittest import TestCase +from tests.testing_utils import get_sam_command, run_command + + +class GenerateOpenApiIntegBase(TestCase): + template = "template.yaml" + + @classmethod + def setUpClass(cls): + cls.cmd = get_sam_command() + integration_dir = Path(__file__).resolve().parents[1] + cls.test_data_path = str(Path(integration_dir, "testdata", "generate", "openapi")) + + def setUp(self): + self.scratch_dir = str(Path(__file__).resolve().parent.joinpath("tmp", str(uuid.uuid4()).replace("-", "")[:10])) + shutil.rmtree(self.scratch_dir, ignore_errors=True) + os.makedirs(self.scratch_dir) + + self.working_dir = tempfile.mkdtemp(dir=self.scratch_dir) + self.output_file_path = Path(self.working_dir, "openapi.yaml") + + def tearDown(self): + self.working_dir and shutil.rmtree(self.working_dir, ignore_errors=True) + self.scratch_dir and shutil.rmtree(self.scratch_dir, ignore_errors=True) + + def get_command_list(self, template_path, api_logical_id=None, output_file=None, format="yaml"): + """Build command list for generate openapi""" + command_list = [self.cmd, "generate", "openapi"] + + if template_path: + command_list += ["-t", template_path] + + if api_logical_id: + command_list += ["--api-logical-id", api_logical_id] + + if output_file: + command_list += ["-o", output_file] + + if format: + command_list += ["--format", format] + + return command_list diff --git a/tests/integration/generate/openapi/test_generate_openapi_command.py b/tests/integration/generate/openapi/test_generate_openapi_command.py new file mode 100644 index 0000000000..33cee72b47 --- /dev/null +++ b/tests/integration/generate/openapi/test_generate_openapi_command.py @@ -0,0 +1,98 @@ +""" +Integration tests for sam generate openapi command +""" + +import os +import json +from pathlib import Path +from tests.integration.generate.openapi.generate_openapi_integ_base import GenerateOpenApiIntegBase +from tests.testing_utils import run_command +from samcli.yamlhelper import yaml_parse + + +class TestGenerateOpenApiCommand(GenerateOpenApiIntegBase): + """Integration tests for generate openapi command""" + + template = "simple_api.yaml" + + def test_generate_openapi_to_stdout(self): + """Test generating OpenAPI to stdout""" + template_path = str(Path(self.test_data_path, self.template)) + command_list = self.get_command_list(template_path=template_path) + + command_result = run_command(command_list, cwd=self.working_dir) + + self.assertEqual(command_result.process.returncode, 0) + stdout = command_result.stdout.decode("utf-8") + + # Verify OpenAPI structure + openapi_doc = yaml_parse(stdout) + self.assertIn("swagger", openapi_doc) + self.assertIn("paths", openapi_doc) + self.assertIn("/hello", openapi_doc["paths"]) + + def test_generate_openapi_to_file(self): + """Test generating OpenAPI to file""" + template_path = str(Path(self.test_data_path, self.template)) + output_file = str(self.output_file_path) + command_list = self.get_command_list(template_path=template_path, output_file=output_file) + + command_result = run_command(command_list, cwd=self.working_dir) + + self.assertEqual(command_result.process.returncode, 0) + self.assertTrue(self.output_file_path.exists()) + + # Verify file contents + with open(output_file, "r") as f: + openapi_doc = yaml_parse(f.read()) + self.assertIn("swagger", openapi_doc) + self.assertIn("paths", openapi_doc) + + def test_generate_openapi_json_format(self): + """Test generating OpenAPI in JSON format""" + template_path = str(Path(self.test_data_path, self.template)) + command_list = self.get_command_list(template_path=template_path, format="json") + + command_result = run_command(command_list, cwd=self.working_dir) + + self.assertEqual(command_result.process.returncode, 0) + stdout = command_result.stdout.decode("utf-8") + + # Verify JSON format + openapi_doc = json.loads(stdout) + self.assertIn("paths", openapi_doc) + + def test_generate_openapi_explicit_api(self): + """Test generating OpenAPI from explicit API resource""" + template = "explicit_api.yaml" + template_path = str(Path(self.test_data_path, template)) + command_list = self.get_command_list(template_path=template_path, api_logical_id="MyApi") + + command_result = run_command(command_list, cwd=self.working_dir) + + self.assertEqual(command_result.process.returncode, 0) + stdout = command_result.stdout.decode("utf-8") + openapi_doc = yaml_parse(stdout) + self.assertIn("swagger", openapi_doc) + + def test_generate_openapi_no_api_error(self): + """Test error when no API resources found""" + template = "no_api.yaml" + template_path = str(Path(self.test_data_path, template)) + command_list = self.get_command_list(template_path=template_path) + + command_result = run_command(command_list, cwd=self.working_dir) + + self.assertNotEqual(command_result.process.returncode, 0) + stderr = command_result.stderr.decode("utf-8") + self.assertIn("No API resources found", stderr) + + def test_generate_command_group(self): + """Test that generate command group exists""" + command_list = [self.cmd, "generate", "--help"] + + command_result = run_command(command_list, cwd=self.working_dir) + + self.assertEqual(command_result.process.returncode, 0) + stdout = command_result.stdout.decode("utf-8") + self.assertIn("openapi", stdout) diff --git a/tests/integration/testdata/generate/openapi/explicit_api.yaml b/tests/integration/testdata/generate/openapi/explicit_api.yaml new file mode 100644 index 0000000000..4abc4830cd --- /dev/null +++ b/tests/integration/testdata/generate/openapi/explicit_api.yaml @@ -0,0 +1,29 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Description: Template with explicit API resource + +Resources: + MyApi: + Type: AWS::Serverless::Api + Properties: + StageName: Prod + DefinitionBody: + swagger: "2.0" + info: + title: My API + version: "1.0" + paths: + /hello: + get: + x-amazon-apigateway-integration: + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyFunction.Arn}/invocations + httpMethod: POST + + MyFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: hello_world/ + Handler: app.lambda_handler + Runtime: python3.9 diff --git a/tests/integration/testdata/generate/openapi/no_api.yaml b/tests/integration/testdata/generate/openapi/no_api.yaml new file mode 100644 index 0000000000..b0d1ea87b6 --- /dev/null +++ b/tests/integration/testdata/generate/openapi/no_api.yaml @@ -0,0 +1,11 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Description: Template without API resources + +Resources: + MyFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: hello_world/ + Handler: app.lambda_handler + Runtime: python3.9 diff --git a/tests/integration/testdata/generate/openapi/simple_api.yaml b/tests/integration/testdata/generate/openapi/simple_api.yaml new file mode 100644 index 0000000000..70d28f21c0 --- /dev/null +++ b/tests/integration/testdata/generate/openapi/simple_api.yaml @@ -0,0 +1,17 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Description: Simple API template for testing + +Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: hello_world/ + Handler: app.lambda_handler + Runtime: python3.9 + Events: + HelloWorld: + Type: Api + Properties: + Path: /hello + Method: get diff --git a/tests/unit/cli/test_command.py b/tests/unit/cli/test_command.py index cd7f052c90..e354ae8c21 100644 --- a/tests/unit/cli/test_command.py +++ b/tests/unit/cli/test_command.py @@ -174,6 +174,7 @@ def test_get_command_root_command_text(self): ("build", "build command output"), ("local", "local command output"), ("validate", "validate command output"), + ("generate", "generate command output"), ("sync", "sync command output"), ("remote", "remote command output"), ], @@ -193,6 +194,7 @@ def test_get_command_root_command_text(self): "build": "build command output", "local": "local command output", "validate": "validate command output", + "generate": "generate command output", "sync": "sync command output", "remote": "remote command output", "package": "package command output", diff --git a/tests/unit/commands/generate/__init__.py b/tests/unit/commands/generate/__init__.py new file mode 100644 index 0000000000..00b1493420 --- /dev/null +++ b/tests/unit/commands/generate/__init__.py @@ -0,0 +1 @@ +"""Unit tests for generate commands""" diff --git a/tests/unit/commands/generate/openapi/__init__.py b/tests/unit/commands/generate/openapi/__init__.py new file mode 100644 index 0000000000..9959adf145 --- /dev/null +++ b/tests/unit/commands/generate/openapi/__init__.py @@ -0,0 +1 @@ +"""Unit tests for generate openapi command""" diff --git a/tests/unit/commands/generate/openapi/test_command.py b/tests/unit/commands/generate/openapi/test_command.py new file mode 100644 index 0000000000..adc26493f5 --- /dev/null +++ b/tests/unit/commands/generate/openapi/test_command.py @@ -0,0 +1,213 @@ +"""Unit tests for generate openapi command entry point""" + +from unittest import TestCase +from unittest.mock import Mock, patch, call + +from samcli.commands.generate.openapi.command import do_cli +from samcli.commands.generate.openapi.exceptions import GenerateOpenApiException + + +class TestGenerateOpenApiCommand(TestCase): + """Test generate openapi command entry point""" + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_successful(self, mock_context_class): + """Test successful command execution""" + # Setup + template_file = "template.yaml" + api_logical_id = "MyApi" + output_file = "output.yaml" + output_format = "yaml" + parameter_overrides = {"Stage": "prod"} + region = "us-east-1" + profile = "default" + + # Create mock context instance + mock_context = Mock() + mock_context_class.return_value.__enter__.return_value = mock_context + + # Execute + do_cli( + template_file=template_file, + api_logical_id=api_logical_id, + output_file=output_file, + output_format=output_format, + openapi_version="3.0", + parameter_overrides=parameter_overrides, + region=region, + profile=profile, + ) + + # Verify context was created with correct parameters + mock_context_class.assert_called_once_with( + template_file=template_file, + api_logical_id=api_logical_id, + output_file=output_file, + output_format=output_format, + openapi_version="3.0", + parameter_overrides=parameter_overrides, + region=region, + profile=profile, + ) + + # Verify run was called + mock_context.run.assert_called_once() + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_with_minimal_parameters(self, mock_context_class): + """Test command with only required parameters""" + # Setup + template_file = "template.yaml" + mock_context = Mock() + mock_context_class.return_value.__enter__.return_value = mock_context + + # Execute + do_cli( + template_file=template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Verify context was created + mock_context_class.assert_called_once_with( + template_file=template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Verify run was called + mock_context.run.assert_called_once() + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_json_format(self, mock_context_class): + """Test command with JSON output format""" + # Setup + template_file = "template.yaml" + output_format = "json" + mock_context = Mock() + mock_context_class.return_value.__enter__.return_value = mock_context + + # Execute + do_cli( + template_file=template_file, + api_logical_id=None, + output_file=None, + output_format=output_format, + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Verify format was passed correctly + self.assertEqual(mock_context_class.call_args[1]["output_format"], "json") + mock_context.run.assert_called_once() + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_with_output_file(self, mock_context_class): + """Test command writing to output file""" + # Setup + template_file = "template.yaml" + output_file = "api-spec.yaml" + mock_context = Mock() + mock_context_class.return_value.__enter__.return_value = mock_context + + # Execute + do_cli( + template_file=template_file, + api_logical_id=None, + output_file=output_file, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Verify output file was passed + self.assertEqual(mock_context_class.call_args[1]["output_file"], output_file) + mock_context.run.assert_called_once() + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_with_parameter_overrides(self, mock_context_class): + """Test command with parameter overrides""" + # Setup + template_file = "template.yaml" + parameter_overrides = {"Stage": "prod", "Environment": "production"} + mock_context = Mock() + mock_context_class.return_value.__enter__.return_value = mock_context + + # Execute + do_cli( + template_file=template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=parameter_overrides, + region=None, + profile=None, + ) + + # Verify parameter overrides were passed + self.assertEqual(mock_context_class.call_args[1]["parameter_overrides"], parameter_overrides) + mock_context.run.assert_called_once() + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_propagates_exceptions(self, mock_context_class): + """Test that exceptions from context are propagated""" + # Setup + template_file = "template.yaml" + mock_context = Mock() + mock_context.run.side_effect = GenerateOpenApiException("Test error") + mock_context_class.return_value.__enter__.return_value = mock_context + + # Execute and verify exception is raised + with self.assertRaises(GenerateOpenApiException) as ex: + do_cli( + template_file=template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + self.assertIn("Test error", str(ex.exception)) + + @patch("samcli.commands.generate.openapi.context.OpenApiContext") + def test_do_cli_context_manager_cleanup(self, mock_context_class): + """Test that context manager properly enters and exits""" + # Setup + template_file = "template.yaml" + mock_context = Mock() + mock_context_class.return_value.__enter__.return_value = mock_context + mock_context_class.return_value.__exit__ = Mock() + + # Execute + do_cli( + template_file=template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Verify context manager was used properly + mock_context_class.return_value.__enter__.assert_called_once() + mock_context_class.return_value.__exit__.assert_called_once() diff --git a/tests/unit/commands/generate/openapi/test_context.py b/tests/unit/commands/generate/openapi/test_context.py new file mode 100644 index 0000000000..8e45d6d504 --- /dev/null +++ b/tests/unit/commands/generate/openapi/test_context.py @@ -0,0 +1,423 @@ +"""Unit tests for OpenAPI generation context""" + +import json +from io import StringIO +from unittest import TestCase +from unittest.mock import Mock, patch, mock_open, MagicMock + +from samcli.commands.generate.openapi.context import OpenApiContext +from samcli.commands.generate.openapi.exceptions import GenerateOpenApiException + + +class TestOpenApiContext(TestCase): + """Test OpenApiContext class""" + + def setUp(self): + """Set up test fixtures""" + self.template_file = "template.yaml" + self.api_logical_id = "MyApi" + self.output_file = "output.yaml" + self.output_format = "yaml" + self.parameter_overrides = {"Stage": "prod"} + self.region = "us-east-1" + self.profile = "default" + + def test_initialization(self): + """Test context initialization""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + output_file=self.output_file, + output_format=self.output_format, + openapi_version="3.0", + parameter_overrides=self.parameter_overrides, + region=self.region, + profile=self.profile, + ) + + self.assertEqual(context.template_file, self.template_file) + self.assertEqual(context.api_logical_id, self.api_logical_id) + self.assertEqual(context.output_file, self.output_file) + self.assertEqual(context.output_format, self.output_format) + self.assertEqual(context.parameter_overrides, self.parameter_overrides) + self.assertEqual(context.region, self.region) + self.assertEqual(context.profile, self.profile) + + def test_context_manager_enter(self): + """Test context manager __enter__ returns self""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + with context as ctx: + self.assertIs(ctx, context) + + def test_context_manager_exit(self): + """Test context manager __exit__ completes without error""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Should not raise any exceptions + with context: + pass + + @patch("samcli.commands.generate.openapi.context.OpenApiGenerator") + @patch("samcli.commands.generate.openapi.context.click") + def test_run_successful_yaml_stdout(self, mock_click, mock_generator_class): + """Test successful run with YAML output to stdout""" + # Setup mock generator + mock_generator = Mock() + openapi_doc = {"openapi": "3.0.0", "info": {"title": "Test API"}} + mock_generator.generate.return_value = openapi_doc + mock_generator_class.return_value = mock_generator + + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + output_file=None, # stdout + output_format="yaml", + openapi_version="3.0", + parameter_overrides=self.parameter_overrides, + region=self.region, + profile=self.profile, + ) + + # Execute + context.run() + + # Verify generator was created correctly + mock_generator_class.assert_called_once_with( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + parameter_overrides=self.parameter_overrides, + region=self.region, + profile=self.profile, + ) + + # Verify generate was called + mock_generator.generate.assert_called_once() + + # Verify output to stdout (click.echo called) + mock_click.echo.assert_called_once() + output = mock_click.echo.call_args[0][0] + self.assertIn("openapi", output) + + @patch("samcli.commands.generate.openapi.context.OpenApiGenerator") + @patch("samcli.commands.generate.openapi.context.click") + @patch("builtins.open", new_callable=mock_open) + def test_run_successful_yaml_file(self, mock_file, mock_click, mock_generator_class): + """Test successful run with YAML output to file""" + # Setup mock generator + mock_generator = Mock() + openapi_doc = {"openapi": "3.0.0", "info": {"title": "Test API"}} + mock_generator.generate.return_value = openapi_doc + mock_generator_class.return_value = mock_generator + + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + output_file="api.yaml", + output_format="yaml", + openapi_version="3.0", + parameter_overrides=self.parameter_overrides, + region=self.region, + profile=self.profile, + ) + + # Execute + context.run() + + # Verify file was opened for writing + mock_file.assert_called_once_with("api.yaml", "w") + + # Verify content was written + handle = mock_file() + handle.write.assert_called() + + # Verify success message shown + mock_click.secho.assert_called() + success_msg = mock_click.secho.call_args[0][0] + self.assertIn("Successfully generated", success_msg) + self.assertIn("api.yaml", success_msg) + + @patch("samcli.commands.generate.openapi.context.OpenApiGenerator") + @patch("samcli.commands.generate.openapi.context.click") + def test_run_successful_json_stdout(self, mock_click, mock_generator_class): + """Test successful run with JSON output to stdout""" + # Setup mock generator + mock_generator = Mock() + openapi_doc = {"openapi": "3.0.0", "info": {"title": "Test API"}} + mock_generator.generate.return_value = openapi_doc + mock_generator_class.return_value = mock_generator + + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + output_file=None, + output_format="json", + openapi_version="3.0", + parameter_overrides=self.parameter_overrides, + region=self.region, + profile=self.profile, + ) + + # Execute + context.run() + + # Verify output is JSON + mock_click.echo.assert_called_once() + output = mock_click.echo.call_args[0][0] + # Should be valid JSON + parsed = json.loads(output) + self.assertEqual(parsed["openapi"], "3.0.0") + + @patch("samcli.commands.generate.openapi.context.OpenApiGenerator") + def test_run_generator_exception(self, mock_generator_class): + """Test run propagates GenerateOpenApiException""" + # Setup mock generator that raises exception + mock_generator = Mock() + mock_generator.generate.side_effect = GenerateOpenApiException("API not found") + mock_generator_class.return_value = mock_generator + + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=self.api_logical_id, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Execute and verify exception is propagated + with self.assertRaises(GenerateOpenApiException) as ex: + context.run() + + self.assertIn("API not found", str(ex.exception)) + + @patch("samcli.commands.generate.openapi.context.OpenApiGenerator") + def test_run_unexpected_exception(self, mock_generator_class): + """Test run wraps unexpected exceptions""" + # Setup mock generator that raises unexpected exception + mock_generator = Mock() + mock_generator.generate.side_effect = RuntimeError("Unexpected error") + mock_generator_class.return_value = mock_generator + + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Execute and verify exception is wrapped + with self.assertRaises(GenerateOpenApiException) as ex: + context.run() + + self.assertIn("Unexpected error", str(ex.exception)) + + def test_format_output_yaml(self): + """Test YAML output formatting""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + openapi_doc = {"openapi": "3.0.0", "info": {"title": "Test API"}} + output = context._format_output(openapi_doc) + + # Should be YAML format + self.assertIsInstance(output, str) + self.assertIn("openapi:", output) + self.assertIn("3.0.0", output) + + def test_format_output_json(self): + """Test JSON output formatting""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="json", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + openapi_doc = {"openapi": "3.0.0", "info": {"title": "Test API"}} + output = context._format_output(openapi_doc) + + # Should be valid JSON + self.assertIsInstance(output, str) + parsed = json.loads(output) + self.assertEqual(parsed["openapi"], "3.0.0") + self.assertEqual(parsed["info"]["title"], "Test API") + + @patch("builtins.open", new_callable=mock_open) + def test_write_output_to_file(self, mock_file): + """Test writing output to file""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file="output.yaml", + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + content = "openapi: 3.0.0\n" + context._write_output(content) + + # Verify file operations + mock_file.assert_called_once_with("output.yaml", "w") + handle = mock_file() + handle.write.assert_called_once_with(content) + + @patch("samcli.commands.generate.openapi.context.click") + def test_write_output_to_stdout(self, mock_click): + """Test writing output to stdout""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + content = "openapi: 3.0.0\n" + context._write_output(content) + + # Verify click.echo was called + mock_click.echo.assert_called_once_with(content) + + @patch("builtins.open", side_effect=IOError("Permission denied")) + def test_write_output_file_error(self, mock_file): + """Test file write error handling""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file="output.yaml", + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + content = "openapi: 3.0.0\n" + + # Verify exception is raised + with self.assertRaises(GenerateOpenApiException) as ex: + context._write_output(content) + + self.assertIn("Failed to write to file", str(ex.exception)) + self.assertIn("output.yaml", str(ex.exception)) + + @patch("samcli.commands.generate.openapi.context.click") + def test_display_success_with_file(self, mock_click): + """Test success message when writing to file""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file="api.yaml", + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + context._display_success() + + # Verify success message shown with file info + mock_click.secho.assert_called_once() + message = mock_click.secho.call_args[0][0] + self.assertIn("Successfully generated", message) + self.assertIn("api.yaml", message) + self.assertEqual(mock_click.secho.call_args[1]["fg"], "green") + + @patch("samcli.commands.generate.openapi.context.click") + def test_display_success_without_file(self, mock_click): + """Test success message when writing to stdout""" + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + context._display_success() + + # Should not display message when output is stdout + # (to avoid cluttering piped output) + mock_click.secho.assert_not_called() + + @patch("samcli.commands.generate.openapi.context.OpenApiGenerator") + @patch("samcli.commands.generate.openapi.context.click") + def test_run_with_none_parameters(self, mock_click, mock_generator_class): + """Test run with all optional parameters as None""" + # Setup mock generator + mock_generator = Mock() + openapi_doc = {"openapi": "3.0.0"} + mock_generator.generate.return_value = openapi_doc + mock_generator_class.return_value = mock_generator + + context = OpenApiContext( + template_file=self.template_file, + api_logical_id=None, + output_file=None, + output_format="yaml", + openapi_version="3.0", + parameter_overrides=None, + region=None, + profile=None, + ) + + # Should execute without errors + context.run() + + # Verify generator was created with None values + mock_generator_class.assert_called_once_with( + template_file=self.template_file, + api_logical_id=None, + parameter_overrides=None, + region=None, + profile=None, + ) + + mock_generator.generate.assert_called_once() diff --git a/tests/unit/commands/generate/openapi/test_exceptions.py b/tests/unit/commands/generate/openapi/test_exceptions.py new file mode 100644 index 0000000000..d4a68e185d --- /dev/null +++ b/tests/unit/commands/generate/openapi/test_exceptions.py @@ -0,0 +1,216 @@ +"""Unit tests for OpenAPI generation exceptions""" + +from unittest import TestCase + +from samcli.commands.exceptions import UserException +from samcli.commands.generate.openapi.exceptions import ( + GenerateOpenApiException, + ApiResourceNotFoundException, + InvalidApiResourceException, + OpenApiExtractionException, + TemplateTransformationException, + NoApiResourcesFoundException, + MultipleApiResourcesException, +) + + +class TestGenerateOpenApiException(TestCase): + """Test base GenerateOpenApiException""" + + def test_is_user_exception(self): + """Test that GenerateOpenApiException inherits from UserException""" + exception = GenerateOpenApiException("Test error") + self.assertIsInstance(exception, UserException) + + def test_exception_message(self): + """Test exception with custom message""" + message = "Test error message" + exception = GenerateOpenApiException(message) + self.assertEqual(str(exception), message) + + +class TestApiResourceNotFoundException(TestCase): + """Test ApiResourceNotFoundException""" + + def test_exception_with_api_id_only(self): + """Test exception with only API ID""" + api_id = "MyApi" + exception = ApiResourceNotFoundException(api_id) + + self.assertEqual(exception.api_id, api_id) + self.assertIn("MyApi", str(exception)) + self.assertIn("not found", str(exception)) + + def test_exception_with_api_id_and_message(self): + """Test exception with API ID and custom message""" + api_id = "MyApi" + message = "Check your template configuration" + exception = ApiResourceNotFoundException(api_id, message) + + self.assertEqual(exception.api_id, api_id) + self.assertIn("MyApi", str(exception)) + self.assertIn("Check your template configuration", str(exception)) + + def test_exception_is_generate_openapi_exception(self): + """Test that it inherits from GenerateOpenApiException""" + exception = ApiResourceNotFoundException("MyApi") + self.assertIsInstance(exception, GenerateOpenApiException) + + +class TestInvalidApiResourceException(TestCase): + """Test InvalidApiResourceException""" + + def test_exception_with_api_id_only(self): + """Test exception with only API ID""" + api_id = "MyApi" + exception = InvalidApiResourceException(api_id) + + self.assertEqual(exception.api_id, api_id) + self.assertIn("MyApi", str(exception)) + self.assertIn("not valid", str(exception)) + + def test_exception_with_api_id_and_message(self): + """Test exception with API ID and custom message""" + api_id = "MyApi" + message = "Missing required properties" + exception = InvalidApiResourceException(api_id, message) + + self.assertEqual(exception.api_id, api_id) + self.assertIn("MyApi", str(exception)) + self.assertIn("Missing required properties", str(exception)) + + def test_exception_is_generate_openapi_exception(self): + """Test that it inherits from GenerateOpenApiException""" + exception = InvalidApiResourceException("MyApi") + self.assertIsInstance(exception, GenerateOpenApiException) + + +class TestOpenApiExtractionException(TestCase): + """Test OpenApiExtractionException""" + + def test_exception_with_message(self): + """Test exception with message""" + message = "Could not extract OpenAPI from template" + exception = OpenApiExtractionException(message) + + self.assertIn("Failed to extract OpenAPI definition", str(exception)) + self.assertIn("Could not extract OpenAPI from template", str(exception)) + + def test_exception_without_message(self): + """Test exception without message""" + exception = OpenApiExtractionException() + + self.assertIn("Failed to extract OpenAPI definition", str(exception)) + + def test_exception_is_generate_openapi_exception(self): + """Test that it inherits from GenerateOpenApiException""" + exception = OpenApiExtractionException("test") + self.assertIsInstance(exception, GenerateOpenApiException) + + +class TestTemplateTransformationException(TestCase): + """Test TemplateTransformationException""" + + def test_exception_with_message(self): + """Test exception with message""" + message = "Invalid SAM template structure" + exception = TemplateTransformationException(message) + + self.assertIn("Failed to transform SAM template", str(exception)) + self.assertIn("Invalid SAM template structure", str(exception)) + + def test_exception_without_message(self): + """Test exception without message""" + exception = TemplateTransformationException() + + self.assertIn("Failed to transform SAM template", str(exception)) + + def test_exception_is_generate_openapi_exception(self): + """Test that it inherits from GenerateOpenApiException""" + exception = TemplateTransformationException("test") + self.assertIsInstance(exception, GenerateOpenApiException) + + +class TestNoApiResourcesFoundException(TestCase): + """Test NoApiResourcesFoundException""" + + def test_exception_with_default_message(self): + """Test exception with default message""" + exception = NoApiResourcesFoundException() + + self.assertIn("No API resources found", str(exception)) + self.assertIn("AWS::Serverless::Api", str(exception)) + self.assertIn("AWS::Serverless::HttpApi", str(exception)) + + def test_exception_with_custom_message(self): + """Test exception with custom message""" + message = "Template is empty" + exception = NoApiResourcesFoundException(message) + + self.assertIn("No API resources found", str(exception)) + self.assertIn("Template is empty", str(exception)) + + def test_exception_is_generate_openapi_exception(self): + """Test that it inherits from GenerateOpenApiException""" + exception = NoApiResourcesFoundException() + self.assertIsInstance(exception, GenerateOpenApiException) + + +class TestMultipleApiResourcesException(TestCase): + """Test MultipleApiResourcesException""" + + def test_exception_with_single_api(self): + """Test exception with single API in list""" + api_ids = ["MyApi"] + exception = MultipleApiResourcesException(api_ids) + + self.assertEqual(exception.api_ids, api_ids) + self.assertIn("Multiple API resources found", str(exception)) + self.assertIn("MyApi", str(exception)) + self.assertIn("--api-logical-id", str(exception)) + + def test_exception_with_multiple_apis(self): + """Test exception with multiple APIs""" + api_ids = ["Api1", "Api2", "Api3"] + exception = MultipleApiResourcesException(api_ids) + + self.assertEqual(exception.api_ids, api_ids) + self.assertIn("Multiple API resources found", str(exception)) + self.assertIn("Api1", str(exception)) + self.assertIn("Api2", str(exception)) + self.assertIn("Api3", str(exception)) + self.assertIn("--api-logical-id", str(exception)) + + def test_exception_message_format(self): + """Test exception message formatting with commas""" + api_ids = ["FirstApi", "SecondApi"] + exception = MultipleApiResourcesException(api_ids) + + message = str(exception) + # Should have comma-separated API IDs + self.assertIn("FirstApi, SecondApi", message) + + def test_exception_is_generate_openapi_exception(self): + """Test that it inherits from GenerateOpenApiException""" + exception = MultipleApiResourcesException(["Api1"]) + self.assertIsInstance(exception, GenerateOpenApiException) + + +class TestExceptionInheritance(TestCase): + """Test exception inheritance hierarchy""" + + def test_all_exceptions_inherit_from_base(self): + """Test that all custom exceptions inherit from GenerateOpenApiException""" + exceptions_to_test = [ + ApiResourceNotFoundException("test"), + InvalidApiResourceException("test"), + OpenApiExtractionException("test"), + TemplateTransformationException("test"), + NoApiResourcesFoundException("test"), + MultipleApiResourcesException(["test"]), + ] + + for exception in exceptions_to_test: + with self.subTest(exception=type(exception).__name__): + self.assertIsInstance(exception, GenerateOpenApiException) + self.assertIsInstance(exception, UserException) diff --git a/tests/unit/commands/generate/test_generate_group.py b/tests/unit/commands/generate/test_generate_group.py new file mode 100644 index 0000000000..38ba98f087 --- /dev/null +++ b/tests/unit/commands/generate/test_generate_group.py @@ -0,0 +1,34 @@ +"""Unit tests for generate command group""" + +from unittest import TestCase +from unittest.mock import patch, Mock +from click.testing import CliRunner + +from samcli.commands.generate.generate import cli + + +class TestGenerateCommandGroup(TestCase): + """Test the generate command group""" + + def setUp(self): + self.runner = CliRunner() + + def test_generate_command_help(self): + """Test generate command shows help""" + result = self.runner.invoke(cli, ["--help"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Commands:", result.output) + + def test_generate_command_has_openapi_subcommand(self): + """Test that openapi subcommand is registered""" + result = self.runner.invoke(cli, ["--help"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("openapi", result.output) + self.assertIn("Generate OpenAPI specification", result.output) + + def test_generate_group_is_click_group(self): + """Test that cli is a Click group""" + self.assertTrue(hasattr(cli, "commands")) + self.assertIn("openapi", cli.commands) diff --git a/tests/unit/lib/generate/__init__.py b/tests/unit/lib/generate/__init__.py new file mode 100644 index 0000000000..080892a011 --- /dev/null +++ b/tests/unit/lib/generate/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for generate library +""" diff --git a/tests/unit/lib/generate/test_openapi_converter.py b/tests/unit/lib/generate/test_openapi_converter.py new file mode 100644 index 0000000000..394a7edb18 --- /dev/null +++ b/tests/unit/lib/generate/test_openapi_converter.py @@ -0,0 +1,53 @@ +"""Unit tests for OpenAPI converter""" + +from unittest import TestCase +from samcli.lib.generate.openapi_converter import OpenApiConverter + + +class TestOpenApiConverter(TestCase): + """Test OpenApiConverter class""" + + def test_swagger_to_openapi3_conversion(self): + """Test converting Swagger 2.0 to OpenAPI 3.0""" + swagger_doc = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0"}, + "paths": {"/test": {"get": {}}}, + "securityDefinitions": {"ApiKey": {"type": "apiKey", "in": "header", "name": "X-API-Key"}}, + } + + result = OpenApiConverter.swagger_to_openapi3(swagger_doc) + + # Version changed + self.assertEqual(result["openapi"], "3.0.0") + self.assertNotIn("swagger", result) + + # SecurityDefinitions moved + self.assertIn("components", result) + self.assertIn("securitySchemes", result["components"]) + self.assertEqual(result["components"]["securitySchemes"]["ApiKey"]["type"], "apiKey") + self.assertNotIn("securityDefinitions", result) + + def test_already_openapi3(self): + """Test that OpenAPI 3.0 docs are returned unchanged""" + openapi_doc = {"openapi": "3.0.0", "info": {"title": "Test"}, "paths": {}} + + result = OpenApiConverter.swagger_to_openapi3(openapi_doc) + + self.assertEqual(result["openapi"], "3.0.0") + self.assertEqual(result, openapi_doc) + + def test_invalid_input(self): + """Test handling of invalid input""" + self.assertIsNone(OpenApiConverter.swagger_to_openapi3(None)) + self.assertEqual(OpenApiConverter.swagger_to_openapi3([]), []) + self.assertEqual(OpenApiConverter.swagger_to_openapi3("string"), "string") + + def test_no_security_definitions(self): + """Test conversion without security definitions""" + swagger_doc = {"swagger": "2.0", "info": {"title": "Test"}, "paths": {}} + + result = OpenApiConverter.swagger_to_openapi3(swagger_doc) + + self.assertEqual(result["openapi"], "3.0.0") + self.assertNotIn("securityDefinitions", result) diff --git a/tests/unit/lib/generate/test_openapi_generator.py b/tests/unit/lib/generate/test_openapi_generator.py new file mode 100644 index 0000000000..156c2341cf --- /dev/null +++ b/tests/unit/lib/generate/test_openapi_generator.py @@ -0,0 +1,294 @@ +""" +Unit tests for OpenAPI Generator +""" + +from unittest import TestCase +from unittest.mock import Mock, patch, mock_open +from samcli.lib.generate.openapi_generator import OpenApiGenerator +from samcli.commands.generate.openapi.exceptions import ( + NoApiResourcesFoundException, + ApiResourceNotFoundException, + MultipleApiResourcesException, + OpenApiExtractionException, + TemplateTransformationException, +) + + +class TestOpenApiGenerator(TestCase): + def setUp(self): + self.template_file = "template.yaml" + self.generator = OpenApiGenerator(template_file=self.template_file) + + def test_init(self): + """Test OpenApiGenerator initialization""" + generator = OpenApiGenerator( + template_file="test.yaml", + api_logical_id="MyApi", + parameter_overrides={"Key": "Value"}, + region="us-east-1", + profile="default", + ) + + self.assertEqual(generator.template_file, "test.yaml") + self.assertEqual(generator.api_logical_id, "MyApi") + self.assertEqual(generator.parameter_overrides, {"Key": "Value"}) + self.assertEqual(generator.region, "us-east-1") + self.assertEqual(generator.profile, "default") + + @patch("builtins.open", new_callable=mock_open, read_data="Resources:\n MyApi:\n Type: AWS::Serverless::Api") + def test_load_template_success(self, mock_file): + """Test successful template loading""" + template = self.generator._load_template() + + self.assertIsInstance(template, dict) + self.assertIn("Resources", template) + + @patch("builtins.open", side_effect=FileNotFoundError()) + def test_load_template_file_not_found(self, mock_file): + """Test template loading with file not found""" + with self.assertRaises(OpenApiExtractionException) as context: + self.generator._load_template() + + self.assertIn("Template file not found", str(context.exception)) + + def test_find_api_resources(self): + """Test finding API resources in template""" + template = { + "Resources": { + "MyApi": {"Type": "AWS::Serverless::Api", "Properties": {}}, + "MyFunction": {"Type": "AWS::Serverless::Function", "Properties": {}}, + "MyHttpApi": {"Type": "AWS::Serverless::HttpApi", "Properties": {}}, + } + } + + api_resources = self.generator._find_api_resources(template) + + self.assertEqual(len(api_resources), 2) + self.assertIn("MyApi", api_resources) + self.assertIn("MyHttpApi", api_resources) + self.assertNotIn("MyFunction", api_resources) + + def test_find_api_resources_empty(self): + """Test finding API resources when none exist""" + template = { + "Resources": { + "MyFunction": {"Type": "AWS::Serverless::Function", "Properties": {}}, + } + } + + api_resources = self.generator._find_api_resources(template) + + self.assertEqual(len(api_resources), 0) + + def test_select_api_resource_single(self): + """Test selecting API when only one exists""" + api_resources = {"MyApi": {"Type": "AWS::Serverless::Api", "Properties": {}}} + + logical_id, resource = self.generator._select_api_resource(api_resources) + + self.assertEqual(logical_id, "MyApi") + self.assertEqual(resource["Type"], "AWS::Serverless::Api") + + def test_select_api_resource_specified(self): + """Test selecting specific API by logical ID""" + api_resources = { + "Api1": {"Type": "AWS::Serverless::Api", "Properties": {}}, + "Api2": {"Type": "AWS::Serverless::Api", "Properties": {}}, + } + + generator = OpenApiGenerator(template_file="test.yaml", api_logical_id="Api2") + logical_id, resource = generator._select_api_resource(api_resources) + + self.assertEqual(logical_id, "Api2") + + def test_select_api_resource_not_found(self): + """Test selecting API that doesn't exist""" + api_resources = { + "Api1": {"Type": "AWS::Serverless::Api", "Properties": {}}, + } + + generator = OpenApiGenerator(template_file="test.yaml", api_logical_id="Api2") + + with self.assertRaises(ApiResourceNotFoundException): + generator._select_api_resource(api_resources) + + def test_select_api_resource_multiple_no_id(self): + """Test selecting API when multiple exist and none specified""" + api_resources = { + "Api1": {"Type": "AWS::Serverless::Api", "Properties": {}}, + "Api2": {"Type": "AWS::Serverless::Api", "Properties": {}}, + } + + with self.assertRaises(MultipleApiResourcesException): + self.generator._select_api_resource(api_resources) + + def test_extract_existing_definition_body(self): + """Test extracting existing OpenAPI from DefinitionBody""" + resource = { + "Type": "AWS::Serverless::Api", + "Properties": { + "DefinitionBody": { + "swagger": "2.0", + "paths": {"/hello": {"get": {}}}, + } + }, + } + + openapi_doc = self.generator._extract_existing_definition(resource, "MyApi") + + self.assertIsNotNone(openapi_doc) + self.assertEqual(openapi_doc["swagger"], "2.0") + self.assertIn("paths", openapi_doc) + + def test_extract_existing_definition_none(self): + """Test extracting OpenAPI when none defined""" + resource = {"Type": "AWS::Serverless::Api", "Properties": {}} + + openapi_doc = self.generator._extract_existing_definition(resource, "MyApi") + + self.assertIsNone(openapi_doc) + + def test_validate_openapi_valid(self): + """Test validating valid OpenAPI document""" + openapi_doc = { + "swagger": "2.0", + "paths": {"/hello": {"get": {}}}, + } + + result = self.generator._validate_openapi(openapi_doc) + + self.assertTrue(result) + + def test_validate_openapi_missing_version(self): + """Test validating OpenAPI without version field""" + openapi_doc = { + "paths": {"/hello": {"get": {}}}, + } + + result = self.generator._validate_openapi(openapi_doc) + + self.assertFalse(result) + + def test_validate_openapi_missing_paths(self): + """Test validating OpenAPI without paths""" + openapi_doc = { + "swagger": "2.0", + } + + result = self.generator._validate_openapi(openapi_doc) + + self.assertFalse(result) + + def test_validate_openapi_invalid_type(self): + """Test validating invalid OpenAPI document type""" + result = self.generator._validate_openapi(None) + self.assertFalse(result) + + result = self.generator._validate_openapi([]) + self.assertFalse(result) + + result = self.generator._validate_openapi("string") + self.assertFalse(result) + + def test_has_implicit_api_true(self): + """Test detecting implicit API""" + template = { + "Resources": { + "MyFunction": { + "Type": "AWS::Serverless::Function", + "Properties": { + "Events": {"ApiEvent": {"Type": "Api", "Properties": {"Path": "/hello", "Method": "get"}}} + }, + } + } + } + + result = self.generator._has_implicit_api(template) + + self.assertTrue(result) + + def test_has_implicit_api_false(self): + """Test detecting no implicit API""" + template = { + "Resources": { + "MyFunction": { + "Type": "AWS::Serverless::Function", + "Properties": {"Events": {"S3Event": {"Type": "S3", "Properties": {"Bucket": "my-bucket"}}}}, + } + } + } + + result = self.generator._has_implicit_api(template) + + self.assertFalse(result) + + def test_get_api_resources_info(self): + """Test getting API resources information""" + template = { + "Resources": { + "MyApi": {"Type": "AWS::Serverless::Api", "Properties": {}}, + "MyHttpApi": {"Type": "AWS::Serverless::HttpApi", "Properties": {}}, + } + } + + with patch.object(self.generator, "_load_template", return_value=template): + info = self.generator.get_api_resources_info() + + self.assertEqual(len(info), 2) + self.assertEqual(info[0]["LogicalId"], "MyApi") + self.assertEqual(info[0]["Type"], "AWS::Serverless::Api") + + def test_get_api_resources_info_error(self): + """Test getting API resources info with error""" + with patch.object(self.generator, "_load_template", side_effect=Exception("Error")): + info = self.generator.get_api_resources_info() + + self.assertEqual(info, []) + + def test_find_api_resources_no_resources_key(self): + """Test finding API resources when Resources key missing""" + template = {} + api_resources = self.generator._find_api_resources(template) + self.assertEqual(len(api_resources), 0) + + def test_validate_openapi_with_openapi3(self): + """Test validating OpenAPI 3.0 document""" + openapi_doc = { + "openapi": "3.0.0", + "paths": {"/hello": {"get": {}}}, + } + result = self.generator._validate_openapi(openapi_doc) + self.assertTrue(result) + + def test_has_implicit_api_no_events(self): + """Test detecting no implicit API when no events""" + template = {"Resources": {"MyFunction": {"Type": "AWS::Serverless::Function", "Properties": {}}}} + result = self.generator._has_implicit_api(template) + self.assertFalse(result) + + def test_has_implicit_api_no_functions(self): + """Test detecting no implicit API when no functions""" + template = {"Resources": {}} + result = self.generator._has_implicit_api(template) + self.assertFalse(result) + + def test_extract_existing_definition_with_ref(self): + """Test extracting when DefinitionBody has Ref""" + resource = {"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": {"Ref": "SomeParameter"}}} + openapi_doc = self.generator._extract_existing_definition(resource, "MyApi") + # Refs are not expanded, so result should be the Ref dict + self.assertIsNotNone(openapi_doc) + + def test_find_api_resources_multiple_types(self): + """Test finding both RestApi and HttpApi""" + template = { + "Resources": { + "RestApi": {"Type": "AWS::Serverless::Api", "Properties": {}}, + "HttpApi": {"Type": "AWS::Serverless::HttpApi", "Properties": {}}, + "Table": {"Type": "AWS::DynamoDB::Table", "Properties": {}}, + } + } + api_resources = self.generator._find_api_resources(template) + self.assertEqual(len(api_resources), 2) + self.assertIn("RestApi", api_resources) + self.assertIn("HttpApi", api_resources)