@@ -288,299 +288,27 @@ def _prepare_response_content(
288
288
"""
289
289
Prepares the response content for serialization.
290
290
"""
291
- if isinstance (res , BaseModel ):
292
- return _model_dump (
291
+ if isinstance (res , BaseModel ): # pragma: no cover
292
+ return _model_dump ( # pragma: no cover
293
293
res ,
294
294
by_alias = True ,
295
295
exclude_unset = exclude_unset ,
296
296
exclude_defaults = exclude_defaults ,
297
297
exclude_none = exclude_none ,
298
298
)
299
- elif isinstance (res , list ):
300
- return [
299
+ elif isinstance (res , list ): # pragma: no cover
300
+ return [ # pragma: no cover
301
301
self ._prepare_response_content (item , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
302
302
for item in res
303
303
]
304
- elif isinstance (res , dict ):
305
- return {
304
+ elif isinstance (res , dict ): # pragma: no cover
305
+ return { # pragma: no cover
306
306
k : self ._prepare_response_content (v , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
307
307
for k , v in res .items ()
308
308
}
309
- elif dataclasses .is_dataclass (res ):
310
- return dataclasses .asdict (res ) # type: ignore[arg-type]
311
- return res
312
-
313
-
314
- class OpenAPIValidationMiddleware (BaseMiddlewareHandler ):
315
- """
316
- OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
317
- Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
318
- should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
319
-
320
- Example
321
- --------
322
-
323
- ```python
324
- from pydantic import BaseModel
325
-
326
- from aws_lambda_powertools.event_handler.api_gateway import (
327
- APIGatewayRestResolver,
328
- )
329
-
330
- class Todo(BaseModel):
331
- name: str
332
-
333
- app = APIGatewayRestResolver(enable_validation=True)
334
-
335
- @app.get("/todos")
336
- def get_todos(): list[Todo]:
337
- return [Todo(name="hello world")]
338
- ```
339
- """
340
-
341
- def __init__ (
342
- self ,
343
- validation_serializer : Callable [[Any ], str ] | None = None ,
344
- has_response_validation_error : bool = False ,
345
- ):
346
- """
347
- Initialize the OpenAPIValidationMiddleware.
348
-
349
- Parameters
350
- ----------
351
- validation_serializer : Callable, optional
352
- Optional serializer to use when serializing the response for validation.
353
- Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
354
-
355
- has_response_validation_error: bool, optional
356
- Optional flag used to distinguish between payload and validation errors.
357
- By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
358
- """
359
- self ._validation_serializer = validation_serializer
360
- self ._has_response_validation_error = has_response_validation_error
361
-
362
- def handler (self , app : EventHandlerInstance , next_middleware : NextMiddleware ) -> Response :
363
- logger .debug ("OpenAPIValidationMiddleware handler" )
364
-
365
- route : Route = app .context ["_route" ]
366
-
367
- values : dict [str , Any ] = {}
368
- errors : list [Any ] = []
369
-
370
- # Process path values, which can be found on the route_args
371
- path_values , path_errors = _request_params_to_args (
372
- route .dependant .path_params ,
373
- app .context ["_route_args" ],
374
- )
375
-
376
- # Normalize query values before validate this
377
- query_string = _normalize_multi_query_string_with_param (
378
- app .current_event .resolved_query_string_parameters ,
379
- route .dependant .query_params ,
380
- )
381
-
382
- # Process query values
383
- query_values , query_errors = _request_params_to_args (
384
- route .dependant .query_params ,
385
- query_string ,
386
- )
387
-
388
- # Normalize header values before validate this
389
- headers = _normalize_multi_header_values_with_param (
390
- app .current_event .resolved_headers_field ,
391
- route .dependant .header_params ,
392
- )
393
-
394
- # Process header values
395
- header_values , header_errors = _request_params_to_args (
396
- route .dependant .header_params ,
397
- headers ,
398
- )
399
-
400
- values .update (path_values )
401
- values .update (query_values )
402
- values .update (header_values )
403
- errors += path_errors + query_errors + header_errors
404
-
405
- # Process the request body, if it exists
406
- if route .dependant .body_params :
407
- (body_values , body_errors ) = _request_body_to_args (
408
- required_params = route .dependant .body_params ,
409
- received_body = self ._get_body (app ),
410
- )
411
- values .update (body_values )
412
- errors .extend (body_errors )
413
-
414
- if errors :
415
- # Raise the validation errors
416
- raise RequestValidationError (_normalize_errors (errors ))
417
- else :
418
- # Re-write the route_args with the validated values, and call the next middleware
419
- app .context ["_route_args" ] = values
420
-
421
- # Call the handler by calling the next middleware
422
- response = next_middleware (app )
423
-
424
- # Process the response
425
- return self ._handle_response (route = route , response = response )
426
-
427
- def _handle_response (self , * , route : Route , response : Response ):
428
- # Process the response body if it exists
429
- if response .body and response .is_json ():
430
- response .body = self ._serialize_response (
431
- field = route .dependant .return_param ,
432
- response_content = response .body ,
433
- has_route_custom_response_validation = route .custom_response_validation_http_code is not None ,
434
- )
435
-
436
- return response
437
-
438
- def _serialize_response (
439
- self ,
440
- * ,
441
- field : ModelField | None = None ,
442
- response_content : Any ,
443
- include : IncEx | None = None ,
444
- exclude : IncEx | None = None ,
445
- by_alias : bool = True ,
446
- exclude_unset : bool = False ,
447
- exclude_defaults : bool = False ,
448
- exclude_none : bool = False ,
449
- has_route_custom_response_validation : bool = False ,
450
- ) -> Any :
451
- """
452
- Serialize the response content according to the field type.
453
- """
454
- if field :
455
- errors : list [dict [str , Any ]] = []
456
- value = _validate_field (field = field , value = response_content , loc = ("response" ,), existing_errors = errors )
457
- if errors :
458
- # route-level validation must take precedence over app-level
459
- if has_route_custom_response_validation :
460
- raise ResponseValidationError (
461
- errors = _normalize_errors (errors ),
462
- body = response_content ,
463
- source = "route" ,
464
- )
465
- if self ._has_response_validation_error :
466
- raise ResponseValidationError (errors = _normalize_errors (errors ), body = response_content , source = "app" )
467
-
468
- raise RequestValidationError (errors = _normalize_errors (errors ), body = response_content )
469
-
470
- if hasattr (field , "serialize" ):
471
- return field .serialize (
472
- value ,
473
- include = include ,
474
- exclude = exclude ,
475
- by_alias = by_alias ,
476
- exclude_unset = exclude_unset ,
477
- exclude_defaults = exclude_defaults ,
478
- exclude_none = exclude_none ,
479
- )
480
- return jsonable_encoder (
481
- value ,
482
- include = include ,
483
- exclude = exclude ,
484
- by_alias = by_alias ,
485
- exclude_unset = exclude_unset ,
486
- exclude_defaults = exclude_defaults ,
487
- exclude_none = exclude_none ,
488
- custom_serializer = self ._validation_serializer ,
489
- )
490
- else :
491
- # Just serialize the response content returned from the handler.
492
- return jsonable_encoder (response_content , custom_serializer = self ._validation_serializer )
493
-
494
- def _prepare_response_content (
495
- self ,
496
- res : Any ,
497
- * ,
498
- exclude_unset : bool ,
499
- exclude_defaults : bool = False ,
500
- exclude_none : bool = False ,
501
- ) -> Any :
502
- """
503
- Prepares the response content for serialization.
504
- """
505
- if isinstance (res , BaseModel ):
506
- return _model_dump (
507
- res ,
508
- by_alias = True ,
509
- exclude_unset = exclude_unset ,
510
- exclude_defaults = exclude_defaults ,
511
- exclude_none = exclude_none ,
512
- )
513
- elif isinstance (res , list ):
514
- return [
515
- self ._prepare_response_content (item , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
516
- for item in res
517
- ]
518
- elif isinstance (res , dict ):
519
- return {
520
- k : self ._prepare_response_content (v , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
521
- for k , v in res .items ()
522
- }
523
- elif dataclasses .is_dataclass (res ):
524
- return dataclasses .asdict (res ) # type: ignore[arg-type]
525
- return res
526
-
527
- def _get_body (self , app : EventHandlerInstance ) -> dict [str , Any ]:
528
- """
529
- Get the request body from the event, and parse it according to content type.
530
- """
531
- content_type = app .current_event .headers .get ("content-type" , "" ).strip ()
532
-
533
- # Handle JSON content
534
- if not content_type or content_type .startswith (APPLICATION_JSON_CONTENT_TYPE ):
535
- return self ._parse_json_data (app )
536
-
537
- # Handle URL-encoded form data
538
- elif content_type .startswith (APPLICATION_FORM_CONTENT_TYPE ):
539
- return self ._parse_form_data (app )
540
-
541
- else :
542
- raise NotImplementedError ("Only JSON body or Form() are supported" )
543
-
544
- def _parse_json_data (self , app : EventHandlerInstance ) -> dict [str , Any ]:
545
- """Parse JSON data from the request body."""
546
- try :
547
- return app .current_event .json_body
548
- except json .JSONDecodeError as e :
549
- raise RequestValidationError (
550
- [
551
- {
552
- "type" : "json_invalid" ,
553
- "loc" : ("body" , e .pos ),
554
- "msg" : "JSON decode error" ,
555
- "input" : {},
556
- "ctx" : {"error" : e .msg },
557
- },
558
- ],
559
- body = e .doc ,
560
- ) from e
561
-
562
- def _parse_form_data (self , app : EventHandlerInstance ) -> dict [str , Any ]:
563
- """Parse URL-encoded form data from the request body."""
564
- try :
565
- body = app .current_event .decoded_body or ""
566
- # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
567
- parsed = parse_qs (body , keep_blank_values = True )
568
-
569
- result : dict [str , Any ] = {key : values [0 ] if len (values ) == 1 else values for key , values in parsed .items ()}
570
- return result
571
-
572
- except Exception as e : # pragma: no cover
573
- raise RequestValidationError ( # pragma: no cover
574
- [
575
- {
576
- "type" : "form_invalid" ,
577
- "loc" : ("body" ,),
578
- "msg" : "Form data parsing error" ,
579
- "input" : {},
580
- "ctx" : {"error" : str (e )},
581
- },
582
- ],
583
- ) from e
309
+ elif dataclasses .is_dataclass (res ): # pragma: no cover
310
+ return dataclasses .asdict (res ) # type: ignore[arg-type] # pragma: no cover
311
+ return res # pragma: no cover
584
312
585
313
586
314
def _request_params_to_args (
0 commit comments