|
35 | 35 | CONTENT_DISPOSITION_NAME_PARAM = "name="
|
36 | 36 | APPLICATION_JSON_CONTENT_TYPE = "application/json"
|
37 | 37 | APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"
|
| 38 | +MULTIPART_FORM_CONTENT_TYPE = "multipart/form-data" |
38 | 39 |
|
39 | 40 |
|
40 | 41 | class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler):
|
@@ -125,8 +126,12 @@ def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
|
125 | 126 | elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
|
126 | 127 | return self._parse_form_data(app)
|
127 | 128 |
|
| 129 | + # Handle multipart form data |
| 130 | + elif content_type.startswith(MULTIPART_FORM_CONTENT_TYPE): |
| 131 | + return self._parse_multipart_data(app, content_type) |
| 132 | + |
128 | 133 | else:
|
129 |
| - raise NotImplementedError("Only JSON body or Form() are supported") |
| 134 | + raise NotImplementedError(f"Content type '{content_type}' is not supported") |
130 | 135 |
|
131 | 136 | def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
|
132 | 137 | """Parse JSON data from the request body."""
|
@@ -169,6 +174,91 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
|
169 | 174 | ],
|
170 | 175 | ) from e
|
171 | 176 |
|
def _parse_multipart_data(self, app: EventHandlerInstance, content_type: str) -> dict[str, Any]:
    """Parse a multipart/form-data request body into a ``{field name: value}`` dict.

    The body is read from ``app.current_event.body`` and base64-decoded first when
    ``app.current_event.is_base64_encoded`` is true. The part delimiter is taken from
    the ``boundary`` parameter of ``content_type``.

    Parameters
    ----------
    app:
        Resolver instance; only ``current_event.body`` and
        ``current_event.is_base64_encoded`` are read.
    content_type:
        Full Content-Type header value, including the ``boundary`` parameter.

    Returns
    -------
    dict[str, Any]
        Parts that carry a ``filename`` parameter are stored as raw ``bytes``.
        Other parts are decoded as UTF-8 ``str`` (kept as ``bytes`` when they are
        not valid UTF-8). Parts without a ``name`` parameter are skipped, and a
        later part with the same name overwrites an earlier one.

    Raises
    ------
    RequestValidationError
        With type ``multipart_invalid`` for any failure, e.g. a missing boundary.
    """
    import base64
    import re

    try:
        # Get the raw body - it might be base64 encoded
        body = app.current_event.body or ""

        if app.current_event.is_base64_encoded:
            try:
                decoded_bytes = base64.b64decode(body)
            except (ValueError, TypeError):
                # Best effort: the flag was set but the payload is not valid base64
                # (binascii.Error is a ValueError), so use the body as-is.
                decoded_bytes = body.encode("utf-8") if isinstance(body, str) else body
        else:
            decoded_bytes = body.encode("utf-8") if isinstance(body, str) else body

        # Parameter names are case-insensitive; a quoted boundary may contain spaces.
        boundary_match = re.search(r'boundary=(?:"([^"]+)"|([^;,\s]+))', content_type, re.IGNORECASE)
        if boundary_match:
            boundary = boundary_match.group(1) or boundary_match.group(2).strip('"')
        else:
            # Fallback kept from the previous implementation for content types
            # that mention a WebKit boundary without a ``boundary=`` parameter.
            webkit_match = re.search(r"WebKitFormBoundary([a-zA-Z0-9]+)", content_type)
            if not webkit_match:
                raise ValueError("No boundary found in multipart content-type")
            boundary = "WebKitFormBoundary" + webkit_match.group(1)
        boundary_bytes = ("--" + boundary).encode("utf-8")

        parsed_data: dict[str, Any] = {}
        if not decoded_bytes:
            return parsed_data

        # ``\b`` stops ``name=`` from matching inside ``filename=``.
        name_pattern = re.compile(r'\bname="([^"]+)"', re.IGNORECASE)
        filename_pattern = re.compile(r"\bfilename\*?=", re.IGNORECASE)

        # sections[0] is the preamble before the first delimiter.
        for section in decoded_bytes.split(boundary_bytes)[1:]:
            if section.startswith(b"--"):
                # Closing delimiter ("--boundary--"); anything after it is epilogue.
                break
            if not section.strip():
                continue

            # Split headers and content (CRLF preferred, bare LF tolerated)
            header_end = section.find(b"\r\n\r\n")
            separator_length = 4
            if header_end == -1:
                header_end = section.find(b"\n\n")
                separator_length = 2
                if header_end == -1:
                    continue
            content = section[header_end + separator_length :]

            # Only the single line break before the next delimiter is framing;
            # every other byte, whitespace included, belongs to the part.
            if content.endswith(b"\r\n"):
                content = content[:-2]
            elif content.endswith(b"\n"):
                content = content[:-1]

            headers_part = section[:header_end].decode("utf-8", errors="ignore")

            name_match = name_pattern.search(headers_part)
            if not name_match:
                continue
            field_name = name_match.group(1)

            if filename_pattern.search(headers_part):
                # It's a file - store as bytes
                parsed_data[field_name] = content
            else:
                # It's a regular form field - decode as string
                try:
                    parsed_data[field_name] = content.decode("utf-8")
                except UnicodeDecodeError:
                    # If can't decode as text, keep as bytes
                    parsed_data[field_name] = content

        return parsed_data

    except Exception as e:
        raise RequestValidationError(
            [
                {
                    "type": "multipart_invalid",
                    "loc": ("body",),
                    "msg": "Invalid multipart form data",
                    "input": {},
                    "ctx": {"error": str(e)},
                },
            ]
        ) from e
| 261 | + |
172 | 262 |
|
173 | 263 | class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler):
|
174 | 264 | """
|
|
0 commit comments