Skip to content

Commit 3f2089e

Browse files
committed
test(swagger): validate API requests against Swagger in integ tests
Validate that all requests we do to the API are conformant to the swagger, and that all responses are also conformant. This check is strict, meaning no extra fields are allowed to catch problems where there is a typo in the swagger. We only check successful requests as we don't want to fail when we try to send a bad request on purpose. If the request is successful, then it means the schema should have been valid. Signed-off-by: Riccardo Mancini <[email protected]>
1 parent 2935e72 commit 3f2089e

File tree

2 files changed

+229
-0
lines changed

2 files changed

+229
-0
lines changed

tests/framework/http_api.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import requests
1010
from requests_unixsocket import DEFAULT_SCHEME, UnixAdapter
1111

12+
from framework.swagger_validator import SwaggerValidator, ValidationError
13+
1214

1315
class Session(requests.Session):
1416
"""An HTTP over UNIX sockets Session
@@ -65,6 +67,21 @@ def get(self):
6567
self._api.error_callback("GET", self.resource, str(e))
6668
raise
6769
assert res.status_code == HTTPStatus.OK, res.json()
70+
71+
# Validate response against Swagger specification
72+
# only validate successful requests
73+
if res.status_code == HTTPStatus.OK:
74+
try:
75+
response_body = res.json()
76+
self._api.validator.validate_response(
77+
"GET", self.resource, 200, response_body
78+
)
79+
except ValidationError as e:
80+
# Re-raise with more context
81+
raise ValidationError(
82+
f"Response validation failed for GET {self.resource}: {e.message}"
83+
) from e
84+
6885
return res
6986

7087
def request(self, method, path, **kwargs):
@@ -85,6 +102,29 @@ def request(self, method, path, **kwargs):
85102
elif "error" in json:
86103
msg = json["error"]
87104
raise RuntimeError(msg, json, res)
105+
106+
# Validate request against Swagger specification
107+
# do this after the actual request as we only want to validate successful
108+
# requests as the tests may be trying to pass bad requests and assert an
109+
# error is raised.
110+
if kwargs:
111+
try:
112+
self._api.validator.validate_request(method, path, kwargs)
113+
except ValidationError as e:
114+
# Re-raise with more context
115+
raise ValidationError(
116+
f"Request validation failed for {method} {path}: {e.message}"
117+
) from e
118+
119+
if res.status_code == HTTPStatus.OK:
120+
try:
121+
response_body = res.json()
122+
self._api.validator.validate_response(method, path, 200, response_body)
123+
except ValidationError as e:
124+
# Re-raise with more context
125+
raise ValidationError(
126+
f"Response validation failed for {method} {path}: {e.message}"
127+
) from e
88128
return res
89129

90130
def put(self, **kwargs):
@@ -112,6 +152,9 @@ def __init__(self, api_usocket_full_name, *, on_error=None):
112152
self.endpoint = DEFAULT_SCHEME + url_encoded_path
113153
self.session = Session()
114154

155+
# Initialize the swagger validator
156+
self.validator = SwaggerValidator()
157+
115158
self.describe = Resource(self, "/")
116159
self.vm = Resource(self, "/vm")
117160
self.vm_config = Resource(self, "/vm/config")
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""A validator for Firecracker API Swagger schema"""
5+
6+
from pathlib import Path
7+
8+
import yaml
9+
from jsonschema import Draft4Validator, ValidationError
10+
11+
12+
def _filter_none_recursive(data):
13+
if isinstance(data, dict):
14+
return {k: _filter_none_recursive(v) for k, v in data.items() if v is not None}
15+
if isinstance(data, list):
16+
return [_filter_none_recursive(item) for item in data if item is not None]
17+
return data
18+
19+
20+
class SwaggerValidator:
21+
"""Validator for API requests against the Swagger/OpenAPI specification"""
22+
23+
_instance = None
24+
_initialized = False
25+
26+
def __new__(cls):
27+
if cls._instance is None:
28+
cls._instance = super().__new__(cls)
29+
return cls._instance
30+
31+
def __init__(self):
32+
"""Initialize the validator with the Swagger specification."""
33+
if self._initialized:
34+
return
35+
self._initialized = True
36+
37+
swagger_path = (
38+
Path(__file__).parent.parent.parent
39+
/ "src"
40+
/ "firecracker"
41+
/ "swagger"
42+
/ "firecracker.yaml"
43+
)
44+
45+
with open(swagger_path, "r", encoding="utf-8") as f:
46+
self.swagger_spec = yaml.safe_load(f)
47+
48+
# Cache validators for each endpoint
49+
self._validators = {}
50+
self._build_validators()
51+
52+
def _build_validators(self):
53+
"""Build JSON schema validators for each endpoint."""
54+
paths = self.swagger_spec.get("paths", {})
55+
definitions = self.swagger_spec.get("definitions", {})
56+
57+
for path, methods in paths.items():
58+
for method, spec in methods.items():
59+
if method.upper() not in ["GET", "PUT", "PATCH", "POST", "DELETE"]:
60+
continue
61+
62+
# Build request body validators
63+
parameters = spec.get("parameters", [])
64+
for param in parameters:
65+
if param.get("in") == "body" and "schema" in param:
66+
schema = self._resolve_schema(param["schema"], definitions)
67+
if method.upper() == "PATCH":
68+
# do not validate required fields on PATCH requests
69+
schema["required"] = []
70+
key = ("request", method.upper(), path)
71+
self._validators[key] = Draft4Validator(schema)
72+
73+
# Build response validators for 200/204 responses
74+
responses = spec.get("responses", {})
75+
for status_code, response_spec in responses.items():
76+
if str(status_code) in ["200", "204"] and "schema" in response_spec:
77+
schema = self._resolve_schema(
78+
response_spec["schema"], definitions
79+
)
80+
key = ("response", method.upper(), path, str(status_code))
81+
self._validators[key] = Draft4Validator(schema)
82+
83+
def _resolve_schema(self, schema, definitions):
84+
"""Resolve $ref references in schema."""
85+
if "$ref" in schema:
86+
ref_path = schema["$ref"]
87+
if ref_path.startswith("#/definitions/"):
88+
def_name = ref_path.split("/")[-1]
89+
if def_name in definitions:
90+
return self._resolve_schema(definitions[def_name], definitions)
91+
92+
# Recursively resolve nested schemas
93+
resolved = schema.copy()
94+
if "properties" in resolved:
95+
resolved["properties"] = {
96+
k: self._resolve_schema(v, definitions)
97+
for k, v in resolved["properties"].items()
98+
}
99+
if "items" in resolved and isinstance(resolved["items"], dict):
100+
resolved["items"] = self._resolve_schema(resolved["items"], definitions)
101+
102+
if not "additionalProperties" in resolved:
103+
resolved["additionalProperties"] = False
104+
105+
return resolved
106+
107+
def validate_request(self, method, path, body):
108+
"""
109+
Validate a request body against the Swagger specification.
110+
111+
Args:
112+
method: HTTP method (GET, PUT, PATCH, etc.)
113+
path: API path (e.g., "/drives/{drive_id}")
114+
body: Request body as a dictionary
115+
116+
Raises:
117+
ValidationError: If the request body doesn't match the schema
118+
"""
119+
# Normalize path - replace specific IDs with parameter placeholders
120+
normalized_path = self._normalize_path(path)
121+
key = ("request", method.upper(), normalized_path)
122+
123+
if key in self._validators:
124+
validator = self._validators[key]
125+
# Remove None values from body before validation
126+
cleaned_body = _filter_none_recursive(body)
127+
validator.validate(cleaned_body)
128+
else:
129+
raise ValidationError(f"{key} is not in the schema")
130+
131+
def validate_response(self, method, path, status_code, body):
132+
"""
133+
Validate a response body against the Swagger specification.
134+
135+
Args:
136+
method: HTTP method (GET, PUT, PATCH, etc.)
137+
path: API path (e.g., "/drives/{drive_id}")
138+
status_code: HTTP status code (e.g., 200, 204)
139+
body: Response body as a dictionary
140+
141+
Raises:
142+
ValidationError: If the response body doesn't match the schema
143+
"""
144+
# Normalize path - replace specific IDs with parameter placeholders
145+
normalized_path = self._normalize_path(path)
146+
key = ("response", method.upper(), normalized_path, str(status_code))
147+
148+
if key in self._validators:
149+
validator = self._validators[key]
150+
# Remove None values from body before validation
151+
cleaned_body = _filter_none_recursive(body)
152+
validator.validate(cleaned_body)
153+
else:
154+
raise ValidationError(f"{key} is not in the schema")
155+
156+
def _normalize_path(self, path):
157+
"""
158+
Normalize a path by replacing specific IDs with parameter placeholders.
159+
160+
E.g., "/drives/rootfs" -> "/drives/{drive_id}"
161+
"""
162+
# Match against known patterns in the swagger spec
163+
paths = self.swagger_spec.get("paths", {})
164+
165+
# Direct match
166+
if path in paths:
167+
return path
168+
169+
# Try to match parameterized paths
170+
parts = path.split("/")
171+
for swagger_path in paths.keys():
172+
swagger_parts = swagger_path.split("/")
173+
if len(parts) == len(swagger_parts):
174+
match = True
175+
for _, (part, swagger_part) in enumerate(zip(parts, swagger_parts)):
176+
# Check if it's a parameter placeholder or exact match
177+
if swagger_part.startswith("{") and swagger_part.endswith("}"):
178+
continue # This is a parameter, any value matches
179+
if part != swagger_part:
180+
match = False
181+
break
182+
183+
if match:
184+
return swagger_path
185+
186+
return path

0 commit comments

Comments
 (0)