|
| 1 | +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""A validator for Firecracker API Swagger schema""" |
| 5 | + |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +import yaml |
| 9 | +from jsonschema import Draft4Validator, ValidationError |
| 10 | + |
| 11 | + |
| 12 | +def _filter_none_recursive(data): |
| 13 | + if isinstance(data, dict): |
| 14 | + return {k: _filter_none_recursive(v) for k, v in data.items() if v is not None} |
| 15 | + if isinstance(data, list): |
| 16 | + return [_filter_none_recursive(item) for item in data if item is not None] |
| 17 | + return data |
| 18 | + |
| 19 | + |
class SwaggerValidator:
    """Validator for API requests and responses against the Swagger/OpenAPI specification.

    Instances are shared: ``__new__`` always hands back the same object, and
    ``__init__`` loads the specification and builds the validators only until
    one attempt has succeeded.
    """

    _instance = None
    _initialized = False

    # Keys under a path item that are treated as HTTP operations; any other
    # key found next to them is skipped.
    _HTTP_METHODS = frozenset({"GET", "PUT", "PATCH", "POST", "DELETE"})
    # Response status codes for which a validator is built.
    _VALIDATED_STATUS_CODES = frozenset({"200", "204"})

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the validator with the Swagger specification."""
        if self._initialized:
            return

        # The spec is looked up relative to this file: two levels above the
        # directory that contains it, then src/firecracker/swagger.
        swagger_path = (
            Path(__file__).parent.parent.parent
            / "src"
            / "firecracker"
            / "swagger"
            / "firecracker.yaml"
        )

        with open(swagger_path, "r", encoding="utf-8") as f:
            self.swagger_spec = yaml.safe_load(f)

        # Cache validators for each endpoint
        self._validators = {}
        self._build_validators()

        # Flag the shared instance only once loading and building succeeded,
        # so a failed attempt is retried by the next construction instead of
        # leaving a half-initialized singleton behind.
        self._initialized = True

    def _build_validators(self):
        """Build JSON schema validators for each endpoint.

        Request validators are stored under ``("request", METHOD, path)`` and
        response validators under ``("response", METHOD, path, status)``.
        """
        paths = self.swagger_spec.get("paths", {})
        definitions = self.swagger_spec.get("definitions", {})

        for path, methods in paths.items():
            for method, spec in methods.items():
                verb = method.upper()
                if verb not in self._HTTP_METHODS:
                    continue

                # Build request body validators
                for param in spec.get("parameters", []):
                    if param.get("in") == "body" and "schema" in param:
                        schema = self._resolve_schema(param["schema"], definitions)
                        if verb == "PATCH":
                            # Do not validate required fields on PATCH
                            # requests. The keyword is dropped rather than
                            # set to [] because Draft 4 requires a non-empty
                            # "required" array. Only the top-level schema is
                            # affected; `schema` is a copy, so the spec's
                            # definitions are left untouched.
                            schema.pop("required", None)
                        self._validators[("request", verb, path)] = Draft4Validator(
                            schema
                        )

                # Build response validators for 200/204 responses
                for status_code, response_spec in spec.get("responses", {}).items():
                    status = str(status_code)
                    if (
                        status in self._VALIDATED_STATUS_CODES
                        and "schema" in response_spec
                    ):
                        schema = self._resolve_schema(
                            response_spec["schema"], definitions
                        )
                        self._validators[("response", verb, path, status)] = (
                            Draft4Validator(schema)
                        )

    def _resolve_schema(self, schema, definitions):
        """
        Resolve $ref references in schema.

        Args:
            schema: Schema dictionary, possibly containing a "$ref"
            definitions: The "definitions" section of the specification

        Returns:
            A shallow copy of the schema in which "#/definitions/..."
            references (at the top level, under "properties" and under a
            dict-valued "items") are replaced by the referenced definition,
            and "additionalProperties" defaults to False.
        """
        if "$ref" in schema:
            ref_path = schema["$ref"]
            if ref_path.startswith("#/definitions/"):
                def_name = ref_path.split("/")[-1]
                if def_name in definitions:
                    return self._resolve_schema(definitions[def_name], definitions)
            # Any other reference is left in place and falls through below.

        # Recursively resolve nested schemas
        resolved = schema.copy()
        if "properties" in resolved:
            resolved["properties"] = {
                k: self._resolve_schema(v, definitions)
                for k, v in resolved["properties"].items()
            }
        if "items" in resolved and isinstance(resolved["items"], dict):
            resolved["items"] = self._resolve_schema(resolved["items"], definitions)

        # Reject undeclared fields unless the spec says otherwise.
        # NOTE(review): this also applies to object schemas that declare no
        # "properties", which then accept only an empty object - confirm
        # this is intended.
        if "additionalProperties" not in resolved:
            resolved["additionalProperties"] = False

        return resolved

    def _validate(self, key, body):
        """
        Validate a body with the validator cached under ``key``.

        Args:
            key: Lookup key as produced by ``_build_validators``
            body: Request or response body

        Raises:
            ValidationError: If no validator exists for ``key`` or the body
                doesn't match the schema
        """
        validator = self._validators.get(key)
        if validator is None:
            raise ValidationError(f"{key} is not in the schema")
        # Remove None values from body before validation
        validator.validate(_filter_none_recursive(body))

    def validate_request(self, method, path, body):
        """
        Validate a request body against the Swagger specification.

        Args:
            method: HTTP method (GET, PUT, PATCH, etc.)
            path: API path (e.g., "/drives/{drive_id}")
            body: Request body as a dictionary

        Raises:
            ValidationError: If the request body doesn't match the schema,
                or the method/path pair has no request schema
        """
        # Normalize path - replace specific IDs with parameter placeholders
        normalized_path = self._normalize_path(path)
        self._validate(("request", method.upper(), normalized_path), body)

    def validate_response(self, method, path, status_code, body):
        """
        Validate a response body against the Swagger specification.

        Args:
            method: HTTP method (GET, PUT, PATCH, etc.)
            path: API path (e.g., "/drives/{drive_id}")
            status_code: HTTP status code (e.g., 200, 204)
            body: Response body as a dictionary

        Raises:
            ValidationError: If the response body doesn't match the schema,
                or the method/path/status combination has no response schema
        """
        # Normalize path - replace specific IDs with parameter placeholders
        normalized_path = self._normalize_path(path)
        self._validate(
            ("response", method.upper(), normalized_path, str(status_code)), body
        )

    def _normalize_path(self, path):
        """
        Normalize a path by replacing specific IDs with parameter placeholders.

        E.g., "/drives/rootfs" -> "/drives/{drive_id}"

        Returns the path unchanged when it is itself a key of the spec's
        "paths" section or when no parameterized path matches it.
        """
        # Match against known patterns in the swagger spec
        paths = self.swagger_spec.get("paths", {})

        # Direct match
        if path in paths:
            return path

        # Try to match parameterized paths; the first match in spec order wins
        parts = path.split("/")
        for swagger_path in paths:
            swagger_parts = swagger_path.split("/")
            if len(parts) != len(swagger_parts):
                continue
            # A "{...}" segment accepts any value; others must match exactly
            if all(
                (expected.startswith("{") and expected.endswith("}"))
                or actual == expected
                for actual, expected in zip(parts, swagger_parts)
            ):
                return swagger_path

        return path
0 commit comments