Skip to content

Commit 93fddc4

Browse files
committed
test(swagger): validate API requests against Swagger in integ tests
Validate that all requests we do to the API are conformant to the swagger, and that all responses are also conformant. This check is strict, meaning no extra fields are allowed to catch problems where there is a typo in the swagger. We only check successful requests as we don't want to fail when we try to send a bad request on purpose. If the request is successful, then it means the schema should have been valid. Signed-off-by: Riccardo Mancini <[email protected]>
1 parent 3e2e256 commit 93fddc4

File tree

3 files changed

+235
-1
lines changed

3 files changed

+235
-1
lines changed

tests/framework/http_api.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import requests
1010
from requests_unixsocket import DEFAULT_SCHEME, UnixAdapter
1111

12+
from framework.swagger_validator import SwaggerValidator, ValidationError
13+
1214

1315
class Session(requests.Session):
1416
"""An HTTP over UNIX sockets Session
@@ -65,6 +67,21 @@ def get(self):
6567
self._api.error_callback("GET", self.resource, str(e))
6668
raise
6769
assert res.status_code == HTTPStatus.OK, res.json()
70+
71+
# Validate response against Swagger specification
72+
# only validate successful requests
73+
if self._api.validator and res.status_code == HTTPStatus.OK:
74+
try:
75+
response_body = res.json()
76+
self._api.validator.validate_response(
77+
"GET", self.resource, 200, response_body
78+
)
79+
except ValidationError as e:
80+
# Re-raise with more context
81+
raise ValidationError(
82+
f"Response validation failed for GET {self.resource}: {e.message}"
83+
) from e
84+
6885
return res
6986

7087
def request(self, method, path, **kwargs):
@@ -85,6 +102,32 @@ def request(self, method, path, **kwargs):
85102
elif "error" in json:
86103
msg = json["error"]
87104
raise RuntimeError(msg, json, res)
105+
106+
# Validate request against Swagger specification
107+
# do this after the actual request as we only want to validate successful
108+
# requests as the tests may be trying to pass bad requests and assert an
109+
# error is raised.
110+
if self._api.validator:
111+
if kwargs:
112+
try:
113+
self._api.validator.validate_request(method, path, kwargs)
114+
except ValidationError as e:
115+
# Re-raise with more context
116+
raise ValidationError(
117+
f"Request validation failed for {method} {path}: {e.message}"
118+
) from e
119+
120+
if res.status_code == HTTPStatus.OK:
121+
try:
122+
response_body = res.json()
123+
self._api.validator.validate_response(
124+
method, path, 200, response_body
125+
)
126+
except ValidationError as e:
127+
# Re-raise with more context
128+
raise ValidationError(
129+
f"Response validation failed for {method} {path}: {e.message}"
130+
) from e
88131
return res
89132

90133
def put(self, **kwargs):
@@ -105,13 +148,16 @@ def patch(self, **kwargs):
105148
class Api:
106149
"""A simple HTTP client for the Firecracker API"""
107150

108-
def __init__(self, api_usocket_full_name, *, on_error=None):
151+
def __init__(self, api_usocket_full_name, *, validate=True, on_error=None):
109152
self.error_callback = on_error
110153
self.socket = api_usocket_full_name
111154
url_encoded_path = urllib.parse.quote_plus(api_usocket_full_name)
112155
self.endpoint = DEFAULT_SCHEME + url_encoded_path
113156
self.session = Session()
114157

158+
# Initialize the swagger validator
159+
self.validator = SwaggerValidator() if validate else None
160+
115161
self.describe = Resource(self, "/")
116162
self.vm = Resource(self, "/vm")
117163
self.vm_config = Resource(self, "/vm/config")

tests/framework/microvm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,13 +634,15 @@ def spawn(
634634
log_show_origin=False,
635635
metrics_path="fc.ndjson",
636636
emit_metrics: bool = False,
637+
validate_api: bool = True,
637638
):
638639
"""Start a microVM as a daemon or in a screen session."""
639640
# pylint: disable=subprocess-run-check
640641
# pylint: disable=too-many-branches
641642
self.jailer.setup()
642643
self.api = Api(
643644
self.jailer.api_socket_path(),
645+
validate=validate_api,
644646
on_error=lambda verb, uri, err_msg: self._dump_debug_information(
645647
f"Error during {verb} {uri}: {err_msg}"
646648
),
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""A validator for Firecracker API Swagger schema"""
5+
6+
from pathlib import Path
7+
8+
import yaml
9+
from jsonschema import Draft4Validator, ValidationError
10+
11+
12+
def _filter_none_recursive(data):
13+
if isinstance(data, dict):
14+
return {k: _filter_none_recursive(v) for k, v in data.items() if v is not None}
15+
if isinstance(data, list):
16+
return [_filter_none_recursive(item) for item in data if item is not None]
17+
return data
18+
19+
20+
class SwaggerValidator:
21+
"""Validator for API requests against the Swagger/OpenAPI specification"""
22+
23+
_instance = None
24+
_initialized = False
25+
26+
def __new__(cls):
27+
if cls._instance is None:
28+
cls._instance = super().__new__(cls)
29+
return cls._instance
30+
31+
def __init__(self):
32+
"""Initialize the validator with the Swagger specification."""
33+
if self._initialized:
34+
return
35+
self._initialized = True
36+
37+
swagger_path = (
38+
Path(__file__).parent.parent.parent
39+
/ "src"
40+
/ "firecracker"
41+
/ "swagger"
42+
/ "firecracker.yaml"
43+
)
44+
45+
with open(swagger_path, "r", encoding="utf-8") as f:
46+
self.swagger_spec = yaml.safe_load(f)
47+
48+
# Cache validators for each endpoint
49+
self._validators = {}
50+
self._build_validators()
51+
52+
def _build_validators(self):
53+
"""Build JSON schema validators for each endpoint."""
54+
paths = self.swagger_spec.get("paths", {})
55+
definitions = self.swagger_spec.get("definitions", {})
56+
57+
for path, methods in paths.items():
58+
for method, spec in methods.items():
59+
if method.upper() not in ["GET", "PUT", "PATCH", "POST", "DELETE"]:
60+
continue
61+
62+
# Build request body validators
63+
parameters = spec.get("parameters", [])
64+
for param in parameters:
65+
if param.get("in") == "body" and "schema" in param:
66+
schema = self._resolve_schema(param["schema"], definitions)
67+
if method.upper() == "PATCH":
68+
# do not validate required fields on PATCH requests
69+
schema["required"] = []
70+
key = ("request", method.upper(), path)
71+
self._validators[key] = Draft4Validator(schema)
72+
73+
# Build response validators for 200/204 responses
74+
responses = spec.get("responses", {})
75+
for status_code, response_spec in responses.items():
76+
if str(status_code) in ["200", "204"] and "schema" in response_spec:
77+
schema = self._resolve_schema(
78+
response_spec["schema"], definitions
79+
)
80+
key = ("response", method.upper(), path, str(status_code))
81+
self._validators[key] = Draft4Validator(schema)
82+
83+
def _resolve_schema(self, schema, definitions):
84+
"""Resolve $ref references in schema."""
85+
if "$ref" in schema:
86+
ref_path = schema["$ref"]
87+
if ref_path.startswith("#/definitions/"):
88+
def_name = ref_path.split("/")[-1]
89+
if def_name in definitions:
90+
return self._resolve_schema(definitions[def_name], definitions)
91+
92+
# Recursively resolve nested schemas
93+
resolved = schema.copy()
94+
if "properties" in resolved:
95+
resolved["properties"] = {
96+
k: self._resolve_schema(v, definitions)
97+
for k, v in resolved["properties"].items()
98+
}
99+
if "items" in resolved and isinstance(resolved["items"], dict):
100+
resolved["items"] = self._resolve_schema(resolved["items"], definitions)
101+
102+
if not "additionalProperties" in resolved:
103+
resolved["additionalProperties"] = False
104+
105+
return resolved
106+
107+
def validate_request(self, method, path, body):
108+
"""
109+
Validate a request body against the Swagger specification.
110+
111+
Args:
112+
method: HTTP method (GET, PUT, PATCH, etc.)
113+
path: API path (e.g., "/drives/{drive_id}")
114+
body: Request body as a dictionary
115+
116+
Raises:
117+
ValidationError: If the request body doesn't match the schema
118+
"""
119+
# Normalize path - replace specific IDs with parameter placeholders
120+
normalized_path = self._normalize_path(path)
121+
key = ("request", method.upper(), normalized_path)
122+
123+
if key in self._validators:
124+
validator = self._validators[key]
125+
# Remove None values from body before validation
126+
cleaned_body = _filter_none_recursive(body)
127+
validator.validate(cleaned_body)
128+
else:
129+
raise ValidationError(f"{key} is not in the schema")
130+
131+
def validate_response(self, method, path, status_code, body):
132+
"""
133+
Validate a response body against the Swagger specification.
134+
135+
Args:
136+
method: HTTP method (GET, PUT, PATCH, etc.)
137+
path: API path (e.g., "/drives/{drive_id}")
138+
status_code: HTTP status code (e.g., 200, 204)
139+
body: Response body as a dictionary
140+
141+
Raises:
142+
ValidationError: If the response body doesn't match the schema
143+
"""
144+
# Normalize path - replace specific IDs with parameter placeholders
145+
normalized_path = self._normalize_path(path)
146+
key = ("response", method.upper(), normalized_path, str(status_code))
147+
148+
if key in self._validators:
149+
validator = self._validators[key]
150+
# Remove None values from body before validation
151+
cleaned_body = _filter_none_recursive(body)
152+
validator.validate(cleaned_body)
153+
else:
154+
raise ValidationError(f"{key} is not in the schema")
155+
156+
def _normalize_path(self, path):
157+
"""
158+
Normalize a path by replacing specific IDs with parameter placeholders.
159+
160+
E.g., "/drives/rootfs" -> "/drives/{drive_id}"
161+
"""
162+
# Match against known patterns in the swagger spec
163+
paths = self.swagger_spec.get("paths", {})
164+
165+
# Direct match
166+
if path in paths:
167+
return path
168+
169+
# Try to match parameterized paths
170+
parts = path.split("/")
171+
for swagger_path in paths.keys():
172+
swagger_parts = swagger_path.split("/")
173+
if len(parts) == len(swagger_parts):
174+
match = True
175+
for _, (part, swagger_part) in enumerate(zip(parts, swagger_parts)):
176+
# Check if it's a parameter placeholder or exact match
177+
if swagger_part.startswith("{") and swagger_part.endswith("}"):
178+
continue # This is a parameter, any value matches
179+
if part != swagger_part:
180+
match = False
181+
break
182+
183+
if match:
184+
return swagger_path
185+
186+
return path

0 commit comments

Comments
 (0)