diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 17deb8eaa9b..ba426d8388e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -2206,6 +2206,162 @@ def configure_openapi( openapi_extensions=openapi_extensions, ) + def configure_openapi_merge( + self, + path: str, + pattern: str | list[str] = "handler.py", + exclude: list[str] | None = None, + resolver_name: str = "app", + recursive: bool = False, + title: str = DEFAULT_OPENAPI_TITLE, + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, + summary: str | None = None, + description: str | None = None, + tags: list[Tag | str] | None = None, + servers: list[Server] | None = None, + terms_of_service: str | None = None, + contact: Contact | None = None, + license_info: License | None = None, + security_schemes: dict[str, SecurityScheme] | None = None, + security: list[dict[str, list[str]]] | None = None, + external_documentation: ExternalDocumentation | None = None, + openapi_extensions: dict[str, Any] | None = None, + on_conflict: Literal["warn", "error", "first", "last"] = "warn", + ): + """Configure OpenAPI merge to generate a unified schema from multiple Lambda handlers. + + This method discovers resolver instances across multiple Python files and merges + their OpenAPI schemas into a single unified specification. Useful for micro-function + architectures where each Lambda has its own resolver. + + Parameters + ---------- + path : str + Root directory path to search for resolver files. + pattern : str | list[str], optional + Glob pattern(s) to match handler files. Default is "handler.py". + exclude : list[str], optional + Patterns to exclude from search. Default excludes tests, __pycache__, and .venv. + resolver_name : str, optional + Name of the resolver variable in handler files. Default is "app". + recursive : bool, optional + Whether to search recursively in subdirectories. Default is False. + title : str + The title of the unified API. + version : str + The version of the OpenAPI document. + openapi_version : str, default = "3.1.0" + The version of the OpenAPI Specification. + summary : str, optional + A short summary of what the application does. + description : str, optional + A verbose explanation of the application behavior. + tags : list[Tag | str], optional + A list of tags used by the specification with additional metadata. + servers : list[Server], optional + An array of Server Objects for connectivity information. + terms_of_service : str, optional + A URL to the Terms of Service for the API. + contact : Contact, optional + The contact information for the exposed API. + license_info : License, optional + The license information for the exposed API. + security_schemes : dict[str, SecurityScheme], optional + Security schemes available in the specification. + security : list[dict[str, list[str]]], optional + Security mechanisms applied globally across the API. + external_documentation : ExternalDocumentation, optional + A link to external documentation for the API. + openapi_extensions : dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. + on_conflict : str, optional + Strategy for handling conflicts when the same path+method is defined + in multiple schemas. Options: "warn" (default), "error", "first", "last". + + Example + ------- + >>> from aws_lambda_powertools.event_handler import APIGatewayRestResolver + >>> + >>> app = APIGatewayRestResolver() + >>> app.configure_openapi_merge( + ... path="./functions", + ... pattern="handler.py", + ... exclude=["**/tests/**"], + ... resolver_name="app", + ... title="My Unified API", + ... version="1.0.0", + ... ) + + See Also + -------- + configure_openapi : Configure OpenAPI for a single resolver + enable_swagger : Enable Swagger UI + """ + from aws_lambda_powertools.event_handler.openapi.merge import OpenAPIMerge + + if exclude is None: + exclude = ["**/tests/**", "**/__pycache__/**", "**/.venv/**"] + + self._openapi_merge = OpenAPIMerge( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + security_schemes=security_schemes, + security=security, + external_documentation=external_documentation, + openapi_extensions=openapi_extensions, + on_conflict=on_conflict, + ) + self._openapi_merge.discover( + path=path, + pattern=pattern, + exclude=exclude, + resolver_name=resolver_name, + recursive=recursive, + ) + + def get_openapi_merge_schema(self) -> dict[str, Any]: + """Get the merged OpenAPI schema from multiple Lambda handlers. + + Returns + ------- + dict[str, Any] + The merged OpenAPI schema. + + Raises + ------ + RuntimeError + If configure_openapi_merge has not been called. + """ + if not hasattr(self, "_openapi_merge") or self._openapi_merge is None: + raise RuntimeError("configure_openapi_merge must be called before get_openapi_merge_schema") + return self._openapi_merge.get_openapi_schema() + + def get_openapi_merge_json_schema(self) -> str: + """Get the merged OpenAPI schema as JSON from multiple Lambda handlers. + + Returns + ------- + str + The merged OpenAPI schema as a JSON string. + + Raises + ------ + RuntimeError + If configure_openapi_merge has not been called. + """ + if not hasattr(self, "_openapi_merge") or self._openapi_merge is None: + raise RuntimeError("configure_openapi_merge must be called before get_openapi_merge_json_schema") + return self._openapi_merge.get_openapi_json_schema() + def enable_swagger( self, *, @@ -2312,32 +2468,38 @@ def swagger_handler(): openapi_servers = servers or [Server(url=(base_path or "/"))] - spec = self.get_openapi_schema( - title=title, - version=version, - openapi_version=openapi_version, - summary=summary, - description=description, - tags=tags, - servers=openapi_servers, - terms_of_service=terms_of_service, - contact=contact, - license_info=license_info, - security_schemes=security_schemes, - security=security, - external_documentation=external_documentation, - openapi_extensions=openapi_extensions, - ) + # Use merged schema if configure_openapi_merge was called, otherwise use regular schema + if hasattr(self, "_openapi_merge") and self._openapi_merge is not None: + # Get merged schema as JSON string (already properly serialized) + escaped_spec = self._openapi_merge.get_openapi_json_schema().replace(" or similar tags. Escaping the forward slash in or similar tags. Escaping the forward slash in bool: + """Check if an AST node is a call to a resolver class.""" + if not isinstance(node, ast.Call): + return False + func = node.func + if isinstance(func, ast.Name) and func.id in RESOLVER_CLASSES: + return True + if isinstance(func, ast.Attribute) and func.attr in RESOLVER_CLASSES: # pragma: no cover + return True + return False # pragma: no cover + + +def _file_has_resolver(file_path: Path, resolver_name: str) -> bool: + """Check if a Python file contains a resolver instance using AST.""" + try: + source = file_path.read_text(encoding="utf-8") + tree = ast.parse(source, filename=str(file_path)) + except (SyntaxError, UnicodeDecodeError): + return False + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == resolver_name: + if _is_resolver_call(node.value): + return True + return False + + +def _is_excluded(file_path: Path, root: Path, exclude_patterns: list[str]) -> bool: + """Check if a file matches any exclusion pattern.""" + relative_str = str(file_path.relative_to(root)) + + for pattern in exclude_patterns: + if pattern.startswith("**/"): + sub_pattern = pattern[3:] + if fnmatch.fnmatch(relative_str, pattern) or fnmatch.fnmatch(file_path.name, sub_pattern): + return True + # Check directory parts - remove trailing glob patterns + clean_pattern = sub_pattern.replace("/**", "").replace("/*", "") + for part in file_path.relative_to(root).parts: + if fnmatch.fnmatch(part, clean_pattern): # pragma: no cover + return True + elif fnmatch.fnmatch(relative_str, pattern) or fnmatch.fnmatch(file_path.name, pattern): # pragma: no cover + return True + return False + + +def _get_glob_pattern(pat: str, recursive: bool) -> str: + """Get the glob pattern based on recursive flag.""" + if recursive and not pat.startswith("**/"): + return f"**/{pat}" + if not recursive and pat.startswith("**/"): + return pat[3:] # Strip **/ prefix + return pat + + +def _discover_resolver_files( + path: str | Path, + pattern: str | list[str], + exclude: list[str], + resolver_name: str, + recursive: bool = False, +) -> list[Path]: + """Discover Python files containing resolver instances.""" + root = Path(path).resolve() + if not root.exists(): + raise FileNotFoundError(f"Path does not exist: {root}") + + patterns = [pattern] if isinstance(pattern, str) else pattern + found_files: set[Path] = set() + + for pat in patterns: + glob_pattern = _get_glob_pattern(pat, recursive) + for file_path in root.glob(glob_pattern): + if ( + file_path.is_file() + and not _is_excluded(file_path, root, exclude) + and _file_has_resolver(file_path, resolver_name) + ): + found_files.add(file_path) + + return sorted(found_files) + + +def _load_resolver(file_path: Path, resolver_name: str) -> Any: + """Load a resolver instance from a Python file.""" + file_path = Path(file_path).resolve() + module_name = f"_powertools_openapi_merge_{file_path.stem}_{id(file_path)}" + + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: # pragma: no cover + raise ImportError(f"Cannot load module from {file_path}") + + module = importlib.util.module_from_spec(spec) + module_dir = str(file_path.parent) + original_path = sys.path.copy() + + try: + if module_dir not in sys.path: + sys.path.insert(0, module_dir) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + if not hasattr(module, resolver_name): + raise AttributeError(f"Resolver '{resolver_name}' not found in {file_path}.") + return getattr(module, resolver_name) + finally: + sys.path = original_path + sys.modules.pop(module_name, None) + + +def _model_to_dict(obj: Any) -> Any: + """Convert Pydantic model to dict if needed.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(by_alias=True, exclude_none=True) + return obj # pragma: no cover + + +class OpenAPIMerge: + """ + Discover and merge OpenAPI schemas from multiple Lambda handlers. + + This class is designed for micro-functions architectures where you have multiple + Lambda functions, each with its own resolver, and need to generate a unified + OpenAPI specification. It's particularly useful for: + + - CI/CD pipelines to generate and publish unified API documentation + - Build-time schema generation for API Gateway imports + - Creating a dedicated Lambda that serves the consolidated OpenAPI spec + + The class uses AST analysis to detect resolver instances without importing modules, + making discovery fast and safe. + + Parameters + ---------- + title : str + The title of the unified API. + version : str + The version of the API (e.g., "1.0.0"). + openapi_version : str, default "3.1.0" + The OpenAPI specification version. + summary : str, optional + A short summary of the API. + description : str, optional + A detailed description of the API. + tags : list[Tag | str], optional + Tags for API documentation organization. + servers : list[Server], optional + Server objects for API connectivity information. + terms_of_service : str, optional + URL to the Terms of Service. + contact : Contact, optional + Contact information for the API. + license_info : License, optional + License information for the API. + security_schemes : dict[str, SecurityScheme], optional + Security scheme definitions. + security : list[dict[str, list[str]]], optional + Global security requirements. + external_documentation : ExternalDocumentation, optional + Link to external documentation. + openapi_extensions : dict[str, Any], optional + OpenAPI specification extensions (x-* fields). + on_conflict : Literal["warn", "error", "first", "last"], default "warn" + Strategy when the same path+method is defined in multiple handlers: + - "warn": Log warning and keep first definition + - "error": Raise OpenAPIMergeError + - "first": Silently keep first definition + - "last": Use last definition (override) + + Example + ------- + **CI/CD Pipeline - Generate unified schema at build time:** + + >>> from aws_lambda_powertools.event_handler.openapi import OpenAPIMerge + >>> + >>> merge = OpenAPIMerge( + ... title="My Unified API", + ... version="1.0.0", + ... description="Consolidated API from multiple Lambda functions", + ... ) + >>> merge.discover( + ... path="./src/functions", + ... pattern="**/handler.py", + ... exclude=["**/tests/**"], + ... ) + >>> schema_json = merge.get_openapi_json_schema() + >>> + >>> # Write to file for API Gateway import or documentation + >>> with open("openapi.json", "w") as f: + ... f.write(schema_json) + + **Dedicated OpenAPI Lambda - Serve unified spec at runtime:** + + >>> from aws_lambda_powertools.event_handler import APIGatewayRestResolver + >>> + >>> app = APIGatewayRestResolver() + >>> app.configure_openapi_merge( + ... path="./functions", + ... pattern="**/handler.py", + ... title="My API", + ... version="1.0.0", + ... ) + >>> app.enable_swagger(path="/docs") # Swagger UI with merged schema + >>> + >>> def handler(event, context): + ... return app.resolve(event, context) + + See Also + -------- + OpenAPIMergeError : Exception raised on merge conflicts when on_conflict="error" + """ + + def __init__( + self, + *, + title: str = DEFAULT_OPENAPI_TITLE, + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, + summary: str | None = None, + description: str | None = None, + tags: list[Tag | str] | None = None, + servers: list[Server] | None = None, + terms_of_service: str | None = None, + contact: Contact | None = None, + license_info: License | None = None, + security_schemes: dict[str, SecurityScheme] | None = None, + security: list[dict[str, list[str]]] | None = None, + external_documentation: ExternalDocumentation | None = None, + openapi_extensions: dict[str, Any] | None = None, + on_conflict: ConflictStrategy = "warn", + ): + self._config = OpenAPIConfig( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + security_schemes=security_schemes, + security=security, + external_documentation=external_documentation, + openapi_extensions=openapi_extensions, + ) + self._schemas: list[dict[str, Any]] = [] + self._discovered_files: list[Path] = [] + self._resolver_name: str = "app" + self._on_conflict = on_conflict + self._cached_schema: dict[str, Any] | None = None + + def discover( + self, + path: str | Path, + pattern: str | list[str] = "handler.py", + exclude: list[str] | None = None, + resolver_name: str = "app", + recursive: bool = False, + ) -> list[Path]: + """ + Discover resolver files in the specified path using glob patterns. + + This method scans the directory tree for Python files matching the pattern, + then uses AST analysis to identify files containing resolver instances. + + Parameters + ---------- + path : str | Path + Root directory to search for handler files. + pattern : str | list[str], default "handler.py" + Glob pattern(s) to match handler files. + exclude : list[str], optional + Patterns to exclude. Defaults to ["**/tests/**", "**/__pycache__/**", "**/.venv/**"]. + resolver_name : str, default "app" + Variable name of the resolver instance in handler files. + recursive : bool, default False + Whether to search recursively in subdirectories. + + Returns + ------- + list[Path] + List of discovered files containing resolver instances. + + Example + ------- + >>> merge = OpenAPIMerge(title="API", version="1.0.0") + >>> files = merge.discover( + ... path="./src", + ... pattern=["handler.py", "api.py"], + ... exclude=["**/tests/**", "**/legacy/**"], + ... recursive=True, + ... ) + >>> print(f"Found {len(files)} handlers") + """ + exclude = exclude or ["**/tests/**", "**/__pycache__/**", "**/.venv/**"] + self._resolver_name = resolver_name + self._discovered_files = _discover_resolver_files(path, pattern, exclude, resolver_name, recursive) + return self._discovered_files + + def add_file(self, file_path: str | Path, resolver_name: str | None = None) -> None: + """Add a specific file to be included in the merge. + + Note: Must be called before get_openapi_schema(). Adding files after + schema generation will not affect the cached result. + """ + path = Path(file_path).resolve() + if path not in self._discovered_files: + self._discovered_files.append(path) + if resolver_name: + self._resolver_name = resolver_name + + def add_schema(self, schema: dict[str, Any]) -> None: + """Add a pre-generated OpenAPI schema to be merged. + + Note: Must be called before get_openapi_schema(). Adding schemas after + schema generation will not affect the cached result. + """ + self._schemas.append(_model_to_dict(schema)) + + def get_openapi_schema(self) -> dict[str, Any]: + """ + Generate the merged OpenAPI schema as a dictionary. + + Loads all discovered resolver files, extracts their OpenAPI schemas, + and merges them into a single unified specification. + + The schema is cached after the first generation for performance. + + Returns + ------- + dict[str, Any] + The merged OpenAPI schema. + + Raises + ------ + OpenAPIMergeError + If on_conflict="error" and duplicate path+method combinations are found. + """ + if self._cached_schema is not None: + return self._cached_schema + + # Load schemas from discovered files + for file_path in self._discovered_files: + try: + resolver = _load_resolver(file_path, self._resolver_name) + if hasattr(resolver, "get_openapi_schema"): + self._schemas.append(_model_to_dict(resolver.get_openapi_schema())) + except (ImportError, AttributeError, FileNotFoundError) as e: # pragma: no cover + logger.warning(f"Failed to load resolver from {file_path}: {e}") + + self._cached_schema = self._merge_schemas() + return self._cached_schema + + def get_openapi_json_schema(self) -> str: + """ + Generate the merged OpenAPI schema as a JSON string. + + This is the recommended method for CI/CD pipelines and build-time + schema generation, as the output can be directly written to a file + or used for API Gateway imports. + + Returns + ------- + str + The merged OpenAPI schema as formatted JSON. + + Example + ------- + >>> merge = OpenAPIMerge(title="API", version="1.0.0") + >>> merge.discover(path="./functions", pattern="**/handler.py") + >>> json_schema = merge.get_openapi_json_schema() + >>> with open("openapi.json", "w") as f: + ... f.write(json_schema) + """ + from aws_lambda_powertools.event_handler.openapi.compat import model_json + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI + + schema = self.get_openapi_schema() + return model_json(OpenAPI(**schema), by_alias=True, exclude_none=True, indent=2) + + @property + def discovered_files(self) -> list[Path]: + """Get the list of discovered resolver files.""" + return self._discovered_files.copy() + + def _merge_schemas(self) -> dict[str, Any]: + """Merge all schemas into a single OpenAPI schema.""" + cfg = self._config + + # Build base schema + merged: dict[str, Any] = { + "openapi": cfg.openapi_version, + "info": {"title": cfg.title, "version": cfg.version}, + "servers": [_model_to_dict(s) for s in cfg.servers] if cfg.servers else [{"url": "/"}], + } + + # Add optional info fields + self._add_optional_info_fields(merged, cfg) + + # Merge paths and components + merged_paths: dict[str, Any] = {} + merged_components: dict[str, dict[str, Any]] = {} + + for schema in self._schemas: + self._merge_paths(schema.get("paths", {}), merged_paths) + self._merge_components(schema.get("components", {}), merged_components) + + # Add security schemes from config + if cfg.security_schemes: + merged_components.setdefault("securitySchemes", {}).update(cfg.security_schemes) + + if merged_paths: + merged["paths"] = merged_paths + if merged_components: + merged["components"] = merged_components + + # Merge tags + if merged_tags := self._merge_tags(): + merged["tags"] = merged_tags + + return merged + + def _add_optional_info_fields(self, merged: dict[str, Any], cfg: OpenAPIConfig) -> None: + """Add optional fields from config to the merged schema.""" + if cfg.summary: + merged["info"]["summary"] = cfg.summary + if cfg.description: + merged["info"]["description"] = cfg.description + if cfg.terms_of_service: + merged["info"]["termsOfService"] = cfg.terms_of_service + if cfg.contact: + merged["info"]["contact"] = _model_to_dict(cfg.contact) + if cfg.license_info: + merged["info"]["license"] = _model_to_dict(cfg.license_info) + if cfg.security: + merged["security"] = cfg.security + if cfg.external_documentation: + merged["externalDocs"] = _model_to_dict(cfg.external_documentation) + if cfg.openapi_extensions: + merged.update(cfg.openapi_extensions) + + def _merge_paths(self, source_paths: dict[str, Any], target: dict[str, Any]) -> None: + """Merge paths from source into target.""" + for path, path_item in source_paths.items(): + if path not in target: + target[path] = path_item + else: + for method, operation in path_item.items(): + if method not in target[path]: + target[path][method] = operation + else: + self._handle_conflict(method, path, target, operation) + + def _handle_conflict(self, method: str, path: str, target: dict, operation: Any) -> None: + """Handle path/method conflict based on strategy.""" + msg = f"Conflict: {method.upper()} {path} is defined in multiple schemas" + if self._on_conflict == "error": + raise OpenAPIMergeError(msg) + elif self._on_conflict == "warn": + logger.warning(f"{msg}. Keeping first definition.") + elif self._on_conflict == "last": + target[path][method] = operation + + def _merge_components(self, source: dict[str, Any], target: dict[str, dict[str, Any]]) -> None: + """Merge components from source into target. + + Note: Components with the same name are silently overwritten (last wins). + This is intentional as component conflicts are typically user errors + (e.g., two handlers defining different 'User' schemas). + """ + for component_type, components in source.items(): + target.setdefault(component_type, {}).update(components) + + def _merge_tags(self) -> list[dict[str, Any]]: + """Merge tags from config and schemas.""" + tags_map: dict[str, dict[str, Any]] = {} + + # Config tags first + for tag in self._config.tags or []: + if isinstance(tag, str): + tags_map[tag] = {"name": tag} + else: + tag_dict = _model_to_dict(tag) + tags_map[tag_dict["name"]] = tag_dict + + # Schema tags (don't override config) + for schema in self._schemas: + for tag in schema.get("tags", []): + name = tag["name"] if isinstance(tag, dict) else tag + if name not in tags_map: + tags_map[name] = tag if isinstance(tag, dict) else {"name": tag} # pragma: no cover + + return list(tags_map.values()) diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/__init__.py b/tests/functional/event_handler/_pydantic/merge_handlers/__init__.py new file mode 100644 index 00000000000..7e881bf743e --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/__init__.py @@ -0,0 +1 @@ +# Sample handlers for OpenAPI merge tests diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/alb_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/alb_handler.py new file mode 100644 index 00000000000..78c39b4eb34 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/alb_handler.py @@ -0,0 +1,17 @@ +"""Sample ALB resolver handler for testing.""" + +from aws_lambda_powertools.event_handler import ALBResolver + +app = ALBResolver() + + +@app.get("/alb/health") +def health_check(): + """ALB health check endpoint.""" + return {"status": "healthy", "resolver": "alb"} + + +@app.post("/alb/process") +def process_data(): + """ALB process endpoint.""" + return {"processed": True} diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/conflict_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/conflict_handler.py new file mode 100644 index 00000000000..8f9a06af797 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/conflict_handler.py @@ -0,0 +1,15 @@ +"""Handler with conflicting route (same as users_handler).""" + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app = APIGatewayRestResolver() + + +@app.get("/users") +def get_users_conflict(): + """This conflicts with users_handler.py.""" + return {"conflict": True} + + +def handler(event, context): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/http_api_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/http_api_handler.py new file mode 100644 index 00000000000..952cb704a10 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/http_api_handler.py @@ -0,0 +1,17 @@ +"""Sample HTTP API resolver handler for testing.""" + +from aws_lambda_powertools.event_handler import APIGatewayHttpResolver + +app = APIGatewayHttpResolver() + + +@app.get("/http/items") +def list_items(): + """List items via HTTP API.""" + return {"items": []} + + +@app.get("/http/items/") +def get_item(item_id: str): + """Get item by ID.""" + return {"item_id": item_id} diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/lambda_url_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/lambda_url_handler.py new file mode 100644 index 00000000000..12a7bb8b181 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/lambda_url_handler.py @@ -0,0 +1,17 @@ +"""Sample Lambda Function URL resolver handler for testing.""" + +from aws_lambda_powertools.event_handler import LambdaFunctionUrlResolver + +app = LambdaFunctionUrlResolver() + + +@app.get("/lambda-url/status") +def get_status(): + """Get Lambda URL status.""" + return {"status": "ok", "resolver": "lambda_url"} + + +@app.post("/lambda-url/webhook") +def webhook(): + """Webhook endpoint.""" + return {"received": True} diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/no_resolver.py b/tests/functional/event_handler/_pydantic/merge_handlers/no_resolver.py new file mode 100644 index 00000000000..6bb35200b82 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/no_resolver.py @@ -0,0 +1,5 @@ +# This file has no resolver - used to test discovery filtering + + +def helper_function(): + return "I'm just a helper" diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/orders_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/orders_handler.py new file mode 100644 index 00000000000..4681bc2b651 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/orders_handler.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app = APIGatewayRestResolver(enable_validation=True) + + +class Order(BaseModel): + id: int + user_id: int + total: float + + +@app.get("/orders") +def get_orders() -> list[Order]: + return [] + + +@app.get("/orders/") +def get_order(order_id: int) -> Order: + return Order(id=order_id, user_id=1, total=99.99) + + +def handler(event, context): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/rest_api_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/rest_api_handler.py new file mode 100644 index 00000000000..6f363510b77 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/rest_api_handler.py @@ -0,0 +1,17 @@ +"""Sample REST API resolver handler for testing.""" + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app = APIGatewayRestResolver() + + +@app.get("/rest/users") +def list_users(): + """List users via REST API.""" + return {"users": []} + + +@app.post("/rest/users") +def create_user(): + """Create user via REST API.""" + return {"created": True} diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/tagged_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/tagged_handler.py new file mode 100644 index 00000000000..c06686aab8c --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/tagged_handler.py @@ -0,0 +1,25 @@ +"""Handler with tags for testing tag merging.""" + +from __future__ import annotations + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app = APIGatewayRestResolver() + + +@app.get("/tagged") +def tagged_endpoint(): + """Endpoint in tagged handler.""" + return {"tagged": True} + + +# Override get_openapi_schema to include tags +_original_get_openapi_schema = app.get_openapi_schema + + +def get_openapi_schema_with_tags(**kwargs): + kwargs.setdefault("tags", ["handler-tag"]) + return _original_get_openapi_schema(**kwargs) + + +app.get_openapi_schema = get_openapi_schema_with_tags diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/users_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/users_handler.py new file mode 100644 index 00000000000..8813511051d --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/users_handler.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app = APIGatewayRestResolver(enable_validation=True) + + +class User(BaseModel): + id: int + name: str + email: str + + +@app.get("/users") +def get_users() -> list[User]: + return [ + User(id=1, name="John", email="john@example.com"), + ] + + +@app.post("/users") +def create_user(user: User) -> User: + return user + + +def handler(event, context): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/zzz_conflict_last_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/zzz_conflict_last_handler.py new file mode 100644 index 00000000000..7a4d16d7f43 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/zzz_conflict_last_handler.py @@ -0,0 +1,11 @@ +"""Handler with conflicting route for testing on_conflict='last'.""" + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app = APIGatewayRestResolver() + + +@app.get("/users", summary="Get users from conflict_last") +def get_users_last(): + """This conflicts with users_handler.py - used to test on_conflict='last'.""" + return {"source": "conflict_last"} diff --git a/tests/functional/event_handler/_pydantic/test_openapi_merge.py b/tests/functional/event_handler/_pydantic/test_openapi_merge.py new file mode 100644 index 00000000000..b4dc1d70232 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/test_openapi_merge.py @@ -0,0 +1,369 @@ +"""Tests for OpenAPI merge functionality.""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi import OpenAPIMerge, OpenAPIMergeError + +MERGE_HANDLERS_PATH = Path(__file__).parent / "merge_handlers" + + +def test_openapi_merge_discover_non_recursive(): + # GIVEN an OpenAPIMerge instance + merge = OpenAPIMerge(title="Non-Recursive API", version="1.0.0") + + # WHEN discovering resolvers without recursion + files = merge.discover( + path=MERGE_HANDLERS_PATH, + pattern="*_handler.py", + recursive=False, + ) + + # THEN it should find handlers in the root directory only + assert len(files) > 0 + for f in files: + assert f.parent == MERGE_HANDLERS_PATH + + +def test_openapi_merge_discover_and_get_schema(): + # GIVEN an OpenAPIMerge instance + merge = OpenAPIMerge(title="My API", version="1.0.0") + + # WHEN discovering resolvers + merge.discover( + path=MERGE_HANDLERS_PATH, + pattern="**/*_handler.py", + exclude=["**/conflict_handler.py"], + resolver_name="app", + ) + + # THEN it should generate merged schema + schema = merge.get_openapi_schema() + assert schema["info"]["title"] == "My API" + assert schema["info"]["version"] == "1.0.0" + assert "/users" in schema["paths"] + assert "/orders" in schema["paths"] + + +def test_openapi_merge_get_json_schema(): + # GIVEN an OpenAPIMerge with discovered resolvers + merge = OpenAPIMerge(title="JSON API", version="2.0.0") + merge.discover( + path=MERGE_HANDLERS_PATH, + pattern="**/users_handler.py", + ) + + # WHEN getting JSON schema + json_schema = merge.get_openapi_json_schema() + + # THEN it should be valid JSON + parsed = json.loads(json_schema) + assert parsed["info"]["title"] == "JSON API" + assert "/users" in parsed["paths"] + + +def test_openapi_merge_discovered_files(): + # GIVEN an OpenAPIMerge with discovered files + merge = OpenAPIMerge(title="Test", version="1.0.0") + merge.discover(path=MERGE_HANDLERS_PATH, pattern="**/users_handler.py") + + # WHEN getting discovered files + files = merge.discovered_files + + # THEN it should return the list + assert len(files) == 1 + assert files[0].name == "users_handler.py" + + +def test_openapi_merge_on_conflict_error(): + # GIVEN handlers with conflicting routes + merge = OpenAPIMerge( + title="Conflict API", + version="1.0.0", + on_conflict="error", + ) + merge.discover( + path=MERGE_HANDLERS_PATH, + pattern="**/*_handler.py", # includes conflict_handler.py + resolver_name="app", + ) + + # WHEN/THEN getting schema should raise + with pytest.raises(OpenAPIMergeError, match="Conflict"): + merge.get_openapi_schema() + + +def test_openapi_merge_on_conflict_warn(): + # GIVEN handlers with conflicting routes + merge = OpenAPIMerge(title="Warn API", version="1.0.0", on_conflict="warn") + merge.discover( + path=MERGE_HANDLERS_PATH, + pattern="**/*_handler.py", + resolver_name="app", + ) + + # WHEN getting schema with mock logger + with patch("aws_lambda_powertools.event_handler.openapi.merge.logger") as mock_logger: + schema = merge.get_openapi_schema() + + # THEN it should log warning and keep first + mock_logger.warning.assert_called() + assert "/users" in schema["paths"] + + +def test_openapi_merge_on_conflict_last(): + # GIVEN handlers with conflicting routes (zzz_ prefix ensures it's discovered last) + merge = OpenAPIMerge(title="Last API", version="1.0.0", on_conflict="last") + merge.discover( + path=MERGE_HANDLERS_PATH, + pattern=["**/users_handler.py", "**/zzz_conflict_last_handler.py"], + resolver_name="app", + ) + + # WHEN getting schema + schema = merge.get_openapi_schema() + + # THEN it should use last definition + assert "/users" in schema["paths"] + assert schema["paths"]["/users"]["get"]["summary"] == "Get users from conflict_last" + + +def test_configure_openapi_merge_and_get_schema(): + # GIVEN a resolver + app = APIGatewayRestResolver() + + # WHEN configuring openapi merge + app.configure_openapi_merge( + path=str(MERGE_HANDLERS_PATH), + pattern="**/*_handler.py", + exclude=["**/conflict_handler.py"], + resolver_name="app", + title="Resolver Merge API", + version="1.0.0", + ) + + # THEN it should return merged schema + schema = app.get_openapi_merge_schema() + assert schema["info"]["title"] == "Resolver Merge API" + assert "/users" in schema["paths"] + assert "/orders" in schema["paths"] + + +def test_configure_openapi_merge_json_schema(): + # GIVEN a configured merge + app = APIGatewayRestResolver() + app.configure_openapi_merge( + path=str(MERGE_HANDLERS_PATH), + pattern="**/users_handler.py", + title="JSON API", + version="1.0.0", + ) + + # WHEN getting JSON schema + json_schema = app.get_openapi_merge_json_schema() + + # THEN it should be valid JSON + parsed = json.loads(json_schema) + assert parsed["info"]["title"] == "JSON API" + + +def test_get_openapi_merge_schema_without_configure_raises(): + # GIVEN a resolver without configure_openapi_merge + app = APIGatewayRestResolver() + + # WHEN/THEN should raise + with pytest.raises(RuntimeError, match="configure_openapi_merge must be called"): + app.get_openapi_merge_schema() + + +def test_get_openapi_merge_json_schema_without_configure_raises(): + # GIVEN a resolver without configure_openapi_merge + app = APIGatewayRestResolver() + + # WHEN/THEN should raise + with pytest.raises(RuntimeError, match="configure_openapi_merge must be called"): + app.get_openapi_merge_json_schema() + + +def test_enable_swagger_uses_merged_schema(): + # GIVEN a resolver with configure_openapi_merge + app = APIGatewayRestResolver() + app.configure_openapi_merge( + path=str(MERGE_HANDLERS_PATH), + pattern="**/*_handler.py", + exclude=["**/conflict_handler.py"], + resolver_name="app", + title="Swagger Merge API", + version="2.0.0", + ) + app.enable_swagger(path="/swagger") + + # WHEN calling swagger endpoint with format=json + event = { + "httpMethod": "GET", + "path": "/swagger", + "queryStringParameters": {"format": "json"}, + "headers": {}, + "requestContext": {"stage": "prod", "path": "/prod/swagger"}, + } + response = app.resolve(event, {}) + + # THEN it should return merged schema + body = json.loads(response["body"]) + assert body["info"]["title"] == "Swagger Merge API" + assert "/users" in body["paths"] + assert "/orders" in body["paths"] + + +def test_enable_swagger_without_merge_uses_regular_schema(): + # GIVEN a resolver without configure_openapi_merge + app = APIGatewayRestResolver() + + @app.get("/local") + def local_endpoint(): + return {"local": True} + + app.enable_swagger(path="/swagger", title="Local API", version="1.0.0") + + # WHEN calling swagger endpoint + event = { + "httpMethod": "GET", + "path": "/swagger", + "queryStringParameters": {"format": "json"}, + "headers": {}, + "requestContext": {"stage": "prod", "path": "/prod/swagger"}, + } + response = app.resolve(event, {}) + + # THEN it should return local schema only + body = json.loads(response["body"]) + assert body["info"]["title"] == "Local API" + assert "/local" in body["paths"] + assert "/users" not in body["paths"] + + +def test_openapi_merge_with_all_optional_fields(): + # GIVEN an OpenAPIMerge with all optional config fields + from aws_lambda_powertools.event_handler.openapi.models import ( + Contact, + ExternalDocumentation, + License, + Server, + Tag, + ) + + merge = OpenAPIMerge( + title="Full Config API", + version="1.0.0", + summary="API summary", + description="API description", + terms_of_service="https://example.com/tos", + contact=Contact(name="Support", email="support@example.com"), + license_info=License(name="MIT"), + servers=[Server(url="https://api.example.com")], + tags=[Tag(name="users", description="User operations"), "orders"], + security=[{"api_key": []}], + security_schemes={"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}}, + external_documentation=ExternalDocumentation(url="https://docs.example.com"), + openapi_extensions={"x-custom": "value"}, + ) + merge.discover(path=MERGE_HANDLERS_PATH, pattern="**/users_handler.py") + + # WHEN getting schema + schema = merge.get_openapi_schema() + + # THEN all optional fields should be present + assert schema["info"]["summary"] == "API summary" + assert schema["info"]["description"] == "API description" + assert schema["info"]["termsOfService"] == "https://example.com/tos" + assert schema["info"]["contact"]["name"] == "Support" + assert schema["info"]["license"]["name"] == "MIT" + assert schema["servers"][0]["url"] == "https://api.example.com" + assert schema["security"] == [{"api_key": []}] + assert "api_key" in schema["components"]["securitySchemes"] + assert "https://docs.example.com" in str(schema["externalDocs"]["url"]) + assert schema["x-custom"] == "value" + # Tags should include both config tags and schema tags + tag_names = [t["name"] for t in schema["tags"]] + assert "users" in tag_names + assert "orders" in tag_names + + +def test_openapi_merge_add_file(): + # GIVEN an OpenAPIMerge instance + merge = OpenAPIMerge(title="Add File API", version="1.0.0") + + # WHEN adding a file manually + handler_path = MERGE_HANDLERS_PATH / "users_handler.py" + merge.add_file(handler_path) + + # THEN it should be in discovered files + assert handler_path.resolve() in merge.discovered_files + + # AND adding the same file again should not duplicate + merge.add_file(handler_path) + assert len([f for f in merge.discovered_files if f.name == "users_handler.py"]) == 1 + + +def test_openapi_merge_add_file_with_resolver_name(): + # GIVEN an OpenAPIMerge instance + merge = OpenAPIMerge(title="Add File API", version="1.0.0") + + # WHEN adding a file with custom resolver name + handler_path = MERGE_HANDLERS_PATH / "users_handler.py" + merge.add_file(handler_path, resolver_name="app") + + # THEN it should update the resolver name + schema = merge.get_openapi_schema() + assert "/users" in schema["paths"] + + +def test_openapi_merge_add_schema(): + # GIVEN an OpenAPIMerge instance + merge = OpenAPIMerge(title="Add Schema API", version="1.0.0") + + # WHEN adding a schema manually + merge.add_schema( + { + "paths": {"/external": {"get": {"summary": "External endpoint"}}}, + }, + ) + + # THEN it should be included in the merged schema + schema = merge.get_openapi_schema() + assert "/external" in schema["paths"] + + +def test_openapi_merge_tags_from_schema(): + # GIVEN an OpenAPIMerge without config tags + merge = OpenAPIMerge(title="Tags API", version="1.0.0") + + # WHEN discovering a handler that has tags in its schema + merge.discover(path=MERGE_HANDLERS_PATH, pattern="**/tagged_handler.py") + + # THEN schema tags should include tags from discovered handler + schema = merge.get_openapi_schema() + tag_names = [t["name"] for t in schema.get("tags", [])] + assert "handler-tag" in tag_names + + +def test_openapi_merge_schema_is_cached(): + # GIVEN an OpenAPIMerge with discovered files + merge = OpenAPIMerge(title="Cached API", version="1.0.0") + merge.discover(path=MERGE_HANDLERS_PATH, pattern="**/users_handler.py") + + # WHEN calling get_openapi_schema multiple times + schema1 = merge.get_openapi_schema() + schema2 = merge.get_openapi_schema() + + # THEN it should return the same cached object + assert schema1 is schema2 + + # AND paths should not be duplicated + assert len([p for p in schema1["paths"] if p == "/users"]) == 1 diff --git a/tests/unit/event_handler/openapi/test_openapi_merge.py b/tests/unit/event_handler/openapi/test_openapi_merge.py new file mode 100644 index 00000000000..21500145b35 --- /dev/null +++ b/tests/unit/event_handler/openapi/test_openapi_merge.py @@ -0,0 +1,98 @@ +"""Unit tests for OpenAPI merge internal functions.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from aws_lambda_powertools.event_handler.openapi.merge import ( + _discover_resolver_files, + _file_has_resolver, + _is_excluded, + _load_resolver, +) + +MERGE_HANDLERS_PATH = Path(__file__).parents[3] / "functional/event_handler/_pydantic/merge_handlers" + + +def test_discover_resolver_files_path_not_exists(): + with pytest.raises(FileNotFoundError, match="Path does not exist"): + _discover_resolver_files("/non/existent/path", "**/*.py", [], "app") + + +def test_discover_resolver_files_multiple_patterns(): + files = _discover_resolver_files( + MERGE_HANDLERS_PATH, + ["**/users_handler.py", "**/orders_handler.py"], + [], + "app", + ) + filenames = {f.name for f in files} + assert "users_handler.py" in filenames + assert "orders_handler.py" in filenames + + +def test_file_has_resolver_syntax_error(tmp_path: Path): + bad_file = tmp_path / "bad.py" + bad_file.write_text("def broken(") + assert _file_has_resolver(bad_file, "app") is False + + +def test_file_has_resolver_wrong_variable_name(tmp_path: Path): + handler_file = tmp_path / "handler.py" + handler_file.write_text(""" +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +router = APIGatewayRestResolver() +""") + assert _file_has_resolver(handler_file, "app") is False + + +def test_file_has_resolver_found(tmp_path: Path): + handler_file = tmp_path / "handler.py" + handler_file.write_text(""" +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +app = APIGatewayRestResolver() +""") + assert _file_has_resolver(handler_file, "app") is True + + +def test_is_excluded_with_directory_pattern(): + root = Path("/project") + assert _is_excluded(Path("/project/tests/handler.py"), root, ["**/tests/**"]) is True + assert _is_excluded(Path("/project/src/handler.py"), root, ["**/tests/**"]) is False + + +def test_is_excluded_with_file_pattern(): + root = Path("/project") + assert _is_excluded(Path("/project/src/test_handler.py"), root, ["**/test_*.py"]) is True + assert _is_excluded(Path("/project/src/handler.py"), root, ["**/test_*.py"]) is False + + +def test_load_resolver_file_not_found(): + with pytest.raises(FileNotFoundError): + _load_resolver(Path("/non/existent/file.py"), "app") + + +def test_load_resolver_not_found_in_module(tmp_path: Path): + handler_file = tmp_path / "handler.py" + handler_file.write_text("x = 1") + + with pytest.raises(AttributeError, match="Resolver 'app' not found"): + _load_resolver(handler_file, "app") + + +def test_load_resolver_success(tmp_path: Path): + handler_file = tmp_path / "handler.py" + handler_file.write_text(""" +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +app = APIGatewayRestResolver() + +@app.get("/test") +def test_endpoint(): + return {"test": True} +""") + + resolver = _load_resolver(handler_file, "app") + assert resolver is not None + assert hasattr(resolver, "get_openapi_schema")