Skip to content

Commit dcfb2e6

Browse files
Adding exception handler + resolvers
1 parent 8628bb9 commit dcfb2e6

File tree

6 files changed

+344
-62
lines changed

6 files changed

+344
-62
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing_extensions import override
1818

1919
from aws_lambda_powertools.event_handler import content_types
20+
from aws_lambda_powertools.event_handler.exception_handling import ExceptionHandlerManager
2021
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
2122
from aws_lambda_powertools.event_handler.openapi.config import OpenAPIConfig
2223
from aws_lambda_powertools.event_handler.openapi.constants import (
@@ -1543,6 +1544,7 @@ def __init__(
15431544
self.processed_stack_frames = []
15441545
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
15451546
self.openapi_config = OpenAPIConfig() # starting an empty dataclass
1547+
self.exception_handler_manager = ExceptionHandlerManager()
15461548
self._has_response_validation_error = response_validation_error_http_code is not None
15471549
self._response_validation_error_http_code = self._validate_response_validation_error_http_code(
15481550
response_validation_error_http_code,
@@ -2409,7 +2411,7 @@ def not_found_handler():
24092411
return Response(status_code=204, content_type=None, headers=_headers, body="")
24102412

24112413
# Customer registered 404 route? Call it.
2412-
custom_not_found_handler = self._lookup_exception_handler(NotFoundError)
2414+
custom_not_found_handler = self.exception_handler_manager.lookup_exception_handler(NotFoundError)
24132415
if custom_not_found_handler:
24142416
return custom_not_found_handler(NotFoundError())
24152417

@@ -2482,26 +2484,10 @@ def not_found(self, func: Callable | None = None):
24822484
return self.exception_handler(NotFoundError)(func)
24832485

24842486
def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]):
2485-
def register_exception_handler(func: Callable):
2486-
if isinstance(exc_class, list): # pragma: no cover
2487-
for exp in exc_class:
2488-
self._exception_handlers[exp] = func
2489-
else:
2490-
self._exception_handlers[exc_class] = func
2491-
return func
2492-
2493-
return register_exception_handler
2494-
2495-
def _lookup_exception_handler(self, exp_type: type) -> Callable | None:
2496-
# Use "Method Resolution Order" to allow for matching against a base class
2497-
# of an exception
2498-
for cls in exp_type.__mro__:
2499-
if cls in self._exception_handlers:
2500-
return self._exception_handlers[cls]
2501-
return None
2487+
return self.exception_handler_manager.exception_handler(exc_class=exc_class)
25022488

25032489
def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuilder | None:
2504-
handler = self._lookup_exception_handler(type(exp))
2490+
handler = self.exception_handler_manager.lookup_exception_handler(type(exp))
25052491
if handler:
25062492
try:
25072493
return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route)
@@ -2595,7 +2581,7 @@ def include_router(self, router: Router, prefix: str | None = None) -> None:
25952581
self._router_middlewares = self._router_middlewares + router._router_middlewares
25962582

25972583
logger.debug("Appending Router exception_handler into App exception_handler.")
2598-
self._exception_handlers.update(router._exception_handlers)
2584+
self.exception_handler_manager.update_exception_handlers(router._exception_handlers)
25992585

26002586
# use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx)
26012587
router.context = self.context

aws_lambda_powertools/event_handler/appsync.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from typing import TYPE_CHECKING, Any, Callable
77

8+
from aws_lambda_powertools.event_handler.exception_handling import ExceptionHandlerManager
89
from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import InvalidBatchResponse, ResolverNotFoundError
910
from aws_lambda_powertools.event_handler.graphql_appsync.router import Router
1011
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
@@ -52,6 +53,7 @@ def __init__(self):
5253
Initialize a new instance of the AppSyncResolver.
5354
"""
5455
super().__init__()
56+
self.exception_handler_manager = ExceptionHandlerManager()
5557
self.context = {} # early init as customers might add context before event resolution
5658
self._exception_handlers: dict[type, Callable] = {}
5759

@@ -151,7 +153,7 @@ def lambda_handler(event, context):
151153
Router.current_event = data_model(event)
152154
response = self._call_single_resolver(event=event, data_model=data_model)
153155
except Exception as exp:
154-
response_builder = self._lookup_exception_handler(type(exp))
156+
response_builder = self.exception_handler_manager.lookup_exception_handler(type(exp))
155157
if response_builder:
156158
return response_builder(exp)
157159
raise
@@ -493,31 +495,4 @@ def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]):
493495
A decorator function that registers the exception handler.
494496
"""
495497

496-
def register_exception_handler(func: Callable):
497-
if isinstance(exc_class, list): # pragma: no cover
498-
for exp in exc_class:
499-
self._exception_handlers[exp] = func
500-
else:
501-
self._exception_handlers[exc_class] = func
502-
return func
503-
504-
return register_exception_handler
505-
506-
def _lookup_exception_handler(self, exp_type: type) -> Callable | None:
507-
"""
508-
Looks up the registered exception handler for the given exception type or its base classes.
509-
510-
Parameters
511-
----------
512-
exp_type (type):
513-
The exception type to look up the handler for.
514-
515-
Returns
516-
-------
517-
Callable | None:
518-
The registered exception handler function if found, otherwise None.
519-
"""
520-
for cls in exp_type.__mro__:
521-
if cls in self._exception_handlers:
522-
return self._exception_handlers[cls]
523-
return None
498+
return self.exception_handler_manager.exception_handler(exc_class=exc_class)
Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,137 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import logging
35
from typing import TYPE_CHECKING, Any, Callable
46

5-
from aws_lambda_powertools.event_handler.events_appsync._registry import ResolverEventsRegistry
7+
from aws_lambda_powertools.event_handler.events_appsync.router import Router
8+
from aws_lambda_powertools.event_handler.exception_handling import ExceptionHandlerManager
9+
from aws_lambda_powertools.utilities.data_classes.appsync_resolver_events_event import AppSyncResolverEventsEvent
610

711
if TYPE_CHECKING:
8-
from aws_lambda_powertools.utilities.data_classes.appsync_resolver_events_event import AppSyncResolverEventsEvent
912
from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext
1013

11-
12-
class AppSyncEventsResolver:
14+
logger = logging.getLogger(__name__)
15+
class AppSyncEventsResolver(Router):
1316
"""
1417
AppSync Events API Resolver
1518
"""
1619

1720
def __init__(self):
21+
super().__init__()
1822
self.context = {} # early init as customers might add context before event resolution
19-
self._publish_registry = ResolverEventsRegistry()
20-
self._async_publish_registry = ResolverEventsRegistry()
21-
self._subscribe_registry = ResolverEventsRegistry()
22-
self._async_subscribe_registry = ResolverEventsRegistry()
23+
self.exception_handler_manager = ExceptionHandlerManager()
24+
self._exception_handlers: dict[type, Callable] = {}
25+
26+
def __call__(
27+
self,
28+
event: dict,
29+
context: LambdaContext,
30+
data_model: type[AppSyncResolverEventsEvent] = AppSyncResolverEventsEvent,
31+
) -> Any:
32+
"""Implicit lambda handler which internally calls `resolve`"""
33+
return self.resolve(event, context, data_model)
2334

2435
def resolve(
2536
self,
2637
event: AppSyncResolverEventsEvent,
2738
context: LambdaContext,
39+
data_model: type[AppSyncResolverEventsEvent] = AppSyncResolverEventsEvent,
2840
) -> Any:
2941
"""Resolves the response based on the provide event and decorator operation and namespaces"""
30-
print(self._publish_registry.__dict__)
3142

32-
def publish(
43+
self.lambda_context = context
44+
Router.lambda_context = context
45+
46+
Router.current_event = data_model(event)
47+
self.current_event = data_model(event)
48+
49+
try:
50+
if self.current_event.info.operation == "PUBLISH":
51+
response = self._call_publish_events(payload=self.current_event.events, data_model=data_model)
52+
else:
53+
response = self._call_subscribe_events(event=event, data_model=data_model)
54+
except Exception as exp:
55+
response_builder = self.exception_handler_manager.lookup_exception_handler(type(exp))
56+
if response_builder:
57+
return response_builder(exp)
58+
raise
59+
60+
self.clear_context()
61+
62+
return response
63+
64+
def _call_subscribe_events(self, payload: list[dict[str, Any]], data_model: type[AppSyncResolverEventsEvent]) -> Any:
65+
# PLACEHOLDER
66+
pass
67+
68+
def _call_publish_events(self, payload: list[dict[str, Any]], data_model: type[AppSyncResolverEventsEvent]) -> Any:
69+
"""Call single event resolver
70+
71+
Parameters
72+
----------
73+
event : dict
74+
Event
75+
data_model : type[AppSyncResolverEvent]
76+
Data_model to decode AppSync event, by default it is of AppSyncResolverEvent type or subclass of it
77+
"""
78+
79+
result = []
80+
logger.debug("Processing direct resolver event")
81+
82+
#self.current_event = data_model(event)
83+
resolver = self._publish_registry.find_resolver(self.current_event.info.channel_path)
84+
if not resolver:
85+
print(f"No resolver found for '{self.current_event.info.channel_path}'")
86+
print(resolver)
87+
88+
if not resolver["aggregate"]:
89+
return resolver["func"](payload=self.current_event.events)
90+
else:
91+
for i in self.current_event.events:
92+
result.append(resolver["func"](payload=i))
93+
94+
return result
95+
96+
def on_publish(
3397
self,
3498
path: str = "/default/*",
3599
aggregate: bool = True,
36100
) -> Callable:
37101
return self._publish_registry.register(path=path, aggregate=aggregate)
38102

39-
def async_publish(
103+
def async_on_publish(
40104
self,
41105
path: str = "/default/*",
42106
aggregate: bool = True,
43107
) -> Callable:
44108
return self._async_publish_registry.register(path=path, aggregate=aggregate)
45109

46-
def subscribe(
110+
def on_subscribe(
47111
self,
48112
path: str = "/default/*",
49113
) -> Callable:
50114
return self._subscribe_registry.register(path=path)
51115

52-
def async_subscribe(
116+
def async_on_subscribe(
53117
self,
54118
path: str = "/default/*",
55119
) -> Callable:
56120
return self._async_subscribe_registry.register(path=path)
121+
122+
def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]):
123+
"""
124+
A decorator function that registers a handler for one or more exception types.
125+
126+
Parameters
127+
----------
128+
exc_class (type[Exception] | list[type[Exception]])
129+
A single exception type or a list of exception types.
130+
131+
Returns
132+
-------
133+
Callable:
134+
A decorator function that registers the exception handler.
135+
"""
136+
137+
return self.exception_handler_manager.exception_handler(exc_class=exc_class)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Callable
5+
6+
7+
class BaseRouter(ABC):
8+
"""Abstract base class for Router (resolvers)"""
9+
10+
@abstractmethod
11+
def on_publish(
12+
self,
13+
path: str = "/default/*",
14+
aggregate: bool = True,
15+
) -> Callable:
16+
raise NotImplementedError
17+
18+
@abstractmethod
19+
def async_on_publish(
20+
self,
21+
path: str = "/default/*",
22+
aggregate: bool = True,
23+
) -> Callable:
24+
raise NotImplementedError
25+
26+
@abstractmethod
27+
def on_subscribe(
28+
self,
29+
path: str = "/default/*",
30+
) -> Callable:
31+
raise NotImplementedError
32+
33+
@abstractmethod
34+
def async_on_subscribe(
35+
self,
36+
path: str = "/default/*",
37+
) -> Callable:
38+
raise NotImplementedError
39+
@abstractmethod
40+
def append_context(self, **additional_context) -> None:
41+
"""
42+
Appends context information available under any route.
43+
44+
Parameters
45+
-----------
46+
**additional_context: dict
47+
Additional context key-value pairs to append.
48+
"""
49+
raise NotImplementedError
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Callable
4+
5+
from aws_lambda_powertools.event_handler.events_appsync._registry import ResolverEventsRegistry
6+
from aws_lambda_powertools.event_handler.events_appsync.base import BaseRouter
7+
from aws_lambda_powertools.utilities.data_classes.appsync_resolver_events_event import AppSyncResolverEventsEvent
8+
9+
if TYPE_CHECKING:
10+
from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncResolverEvent
11+
from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext
12+
13+
14+
class Router(BaseRouter):
15+
16+
context: dict
17+
current_event: AppSyncResolverEventsEvent | None = None
18+
lambda_context: LambdaContext | None = None
19+
20+
def __init__(self):
21+
self.context = {} # early init as customers might add context before event resolution
22+
self._publish_registry = ResolverEventsRegistry()
23+
self._async_publish_registry = ResolverEventsRegistry()
24+
self._subscribe_registry = ResolverEventsRegistry()
25+
self._async_subscribe_registry = ResolverEventsRegistry()
26+
27+
def on_publish(
28+
self,
29+
path: str = "/default/*",
30+
aggregate: bool = True,
31+
) -> Callable:
32+
return self._publish_registry.register(path=path, aggregate=aggregate)
33+
34+
def async_on_publish(
35+
self,
36+
path: str = "/default/*",
37+
aggregate: bool = True,
38+
) -> Callable:
39+
return self._async_publish_registry.register(path=path, aggregate=aggregate)
40+
41+
def on_subscribe(
42+
self,
43+
path: str = "/default/*",
44+
) -> Callable:
45+
return self._subscribe_registry.register(path=path)
46+
47+
def async_on_subscribe(
48+
self,
49+
path: str = "/default/*",
50+
) -> Callable:
51+
return self._async_subscribe_registry.register(path=path)
52+
53+
def append_context(self, **additional_context):
54+
"""Append key=value data as routing context"""
55+
self.context.update(**additional_context)
56+
57+
def clear_context(self):
58+
"""Resets routing context"""
59+
self.context.clear()

0 commit comments

Comments
 (0)