Skip to content

Commit 2df3a76

Browse files
Work work work..
1 parent 49bf21d commit 2df3a76

File tree

8 files changed

+158
-68
lines changed

8 files changed

+158
-68
lines changed

aws_lambda_powertools/event_handler/appsync_events.py

Lines changed: 78 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,22 @@
33
import asyncio
44
import logging
55
import warnings
6-
from typing import TYPE_CHECKING, Any, Callable
6+
from typing import TYPE_CHECKING, Any
77

88
from aws_lambda_powertools.event_handler.events_appsync.router import Router
99
from aws_lambda_powertools.utilities.data_classes.appsync_resolver_events_event import AppSyncResolverEventsEvent
1010
from aws_lambda_powertools.warnings import PowertoolsUserWarning
1111

1212
if TYPE_CHECKING:
13+
from collections.abc import Callable
14+
15+
from aws_lambda_powertools.event_handler.events_appsync.types import ResolverTypeDef
1316
from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext
1417

18+
1519
logger = logging.getLogger(__name__)
20+
21+
1622
class AppSyncEventsResolver(Router):
1723
"""
1824
AppSync Events API Resolver
@@ -25,35 +31,37 @@ def __init__(self):
2531

2632
def __call__(
2733
self,
28-
event: dict,
34+
event: dict | AppSyncResolverEventsEvent,
2935
context: LambdaContext,
3036
) -> Any:
3137
"""Implicit lambda handler which internally calls `resolve`"""
3238
return self.resolve(event, context)
3339

3440
def resolve(
3541
self,
36-
event: AppSyncResolverEventsEvent,
42+
event: dict | AppSyncResolverEventsEvent,
3743
context: LambdaContext,
3844
) -> Any:
3945
"""Resolves the response based on the provide event and decorator operation and namespaces"""
4046

4147
self.lambda_context = context
4248
Router.lambda_context = context
4349

44-
Router.current_event = AppSyncResolverEventsEvent(event)
50+
Router.current_event = (
51+
event if isinstance(event, AppSyncResolverEventsEvent) else AppSyncResolverEventsEvent(event)
52+
)
4553
self.current_event = Router.current_event
4654

4755
if self.current_event.info.operation == "PUBLISH":
48-
return self._call_publish_events(payload=self.current_event.events)
56+
return self._publish_events(payload=self.current_event.events)
4957

50-
response = self._call_subscribe_events()
58+
response = self._subscribe_events()
5159

5260
self.clear_context()
5361

5462
return response
5563

56-
def _call_subscribe_events(self) -> Any:
64+
def _subscribe_events(self) -> Any:
5765
logger.debug(f"Processing subscribe events for path {self.current_event.info.channel_path}")
5866

5967
resolver = self._subscribe_registry.find_resolver(self.current_event.info.channel_path)
@@ -66,7 +74,7 @@ def _call_subscribe_events(self) -> Any:
6674
return
6775
pass
6876

69-
def _call_publish_events(self, payload: list[dict[str, Any]]) -> Any:
77+
def _publish_events(self, payload: list[dict[str, Any]]) -> list[dict[str, Any]] | dict[str, Any]:
7078
"""Call single event resolver
7179
7280
Parameters
@@ -90,34 +98,33 @@ def _call_publish_events(self, payload: list[dict[str, Any]]) -> Any:
9098

9199
if resolver:
92100
logger.debug(f"Found sync resolver. {resolver}")
93-
return self._call_publish_event_sync_resolver(
94-
resolver=resolver["func"],
95-
aggregate=resolver["aggregate"],
101+
return self._process_publish_event_sync_resolver(
102+
resolver=resolver,
96103
)
97104

98105
if async_resolver:
99106
logger.debug(f"Found async resolver. {resolver}")
100107
return asyncio.run(
101108
self._call_publish_event_async_resolver(
102-
resolver=async_resolver["func"],
103-
aggregate=async_resolver["aggregate"],
109+
resolver=async_resolver,
104110
),
105111
)
106112

107113
# No resolver found
108114
# Warning and returning AS IS
109115
warnings.warn(
110-
f"No resolvers were found for publish operations with path {self.current_event.info.channel_path}",
116+
f"No resolvers were found for publish operations with path {self.current_event.info.channel_path}"
117+
"We will return the entire payload as is",
111118
stacklevel=2,
112-
category=PowertoolsUserWarning)
119+
category=PowertoolsUserWarning,
120+
)
113121

114122
return {"events": payload}
115123

116-
def _call_publish_event_sync_resolver(
124+
def _process_publish_event_sync_resolver(
117125
self,
118-
resolver: Callable,
119-
aggregate: bool = True,
120-
) -> list[Any]:
126+
resolver: ResolverTypeDef,
127+
) -> list[dict[str, Any]] | dict[str, Any]:
121128
"""
122129
Calls a synchronous batch resolver function for each event in the current batch.
123130
@@ -140,34 +147,37 @@ def _call_publish_event_sync_resolver(
140147
"""
141148

142149
# Checks whether the entire batch should be processed at once
143-
if aggregate:
144-
# Process the entire batch
145-
response = resolver(payload=self.current_event.events)
150+
if resolver["aggregate"]:
151+
try:
152+
# Process the entire batch
153+
response = resolver["func"](payload=self.current_event.events)
146154

147-
if not isinstance(response, list):
148-
warnings.warn(
149-
"Response must be a list when using aggregate, AppSync will drop those events.",
150-
stacklevel=2,
151-
category=PowertoolsUserWarning)
152-
153-
return response
155+
if not isinstance(response, list):
156+
warnings.warn(
157+
"Response must be a list when using aggregate, AppSync will drop those events.",
158+
stacklevel=2,
159+
category=PowertoolsUserWarning,
160+
)
154161

162+
return response
163+
except Exception as error:
164+
return {"error": self.format_error_response(error)}
155165

156166
# By default, we gracefully append `None` for any records that failed processing
157167
results = []
158168
for idx, event in enumerate(self.current_event.events):
159169
try:
160-
results.append(resolver(payload=event))
161-
except Exception:
170+
results.append(resolver["func"](payload=event))
171+
except Exception as error:
162172
logger.debug(f"Failed to process event number {idx}")
163-
results.append(None)
173+
error_return = {"id": event.get("id"), "error": self.format_error_response(error)}
174+
results.append(error_return)
164175

165176
return results
166177

167178
async def _call_publish_event_async_resolver(
168179
self,
169-
resolver: Callable,
170-
aggregate: bool = True,
180+
resolver: ResolverTypeDef,
171181
) -> list[Any]:
172182
"""
173183
Asynchronously call a batch resolver for each event in the current batch.
@@ -191,28 +201,55 @@ async def _call_publish_event_async_resolver(
191201
"""
192202

193203
# Checks whether the entire batch should be processed at once
194-
if aggregate:
204+
if resolver["aggregate"]:
195205
# Process the entire batch
196-
response = await resolver(event=self.current_batch_event)
206+
response = await resolver["func"](event=self.current_event.events)
197207
if not isinstance(response, list):
198208
warnings.warn(
199209
"Response must be a list when using aggregate, AppSync will drop those events.",
200210
stacklevel=2,
201-
category=PowertoolsUserWarning)
211+
category=PowertoolsUserWarning,
212+
)
202213

203214
return response
204215

205-
response: list = []
216+
response_async: list = []
206217

207218
# Prime coroutines
208-
tasks = [resolver(event=e, **e.arguments) for e in self.current_batch_event]
219+
tasks = [resolver["func"](event=e) for e in self.current_event.events]
209220

210221
# Aggregate results and exceptions, then filter them out
211222
# Use `None` upon exception for graceful error handling at GraphQL engine level
212223
#
213224
# NOTE: asyncio.gather(return_exceptions=True) catches and includes exceptions in the results
214225
# this will become useful when we support exception handling in AppSync resolver
215226
results = await asyncio.gather(*tasks, return_exceptions=True)
216-
response.extend(None if isinstance(ret, Exception) else ret for ret in results)
227+
response_async.extend(None if isinstance(ret, Exception) else ret for ret in results)
217228

218-
return response
229+
return response_async
230+
231+
def include_router(self, router: Router) -> None:
232+
"""Adds all resolvers defined in a router
233+
234+
Parameters
235+
----------
236+
router : Router
237+
A router containing a dict of field resolvers
238+
"""
239+
240+
# Merge app and router context
241+
logger.debug("Merging router and app context")
242+
self.context.update(**router.context)
243+
244+
# use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx)
245+
router.context = self.context
246+
247+
logger.debug("Merging router resolver registries")
248+
self._publish_registry.merge(router._publish_registry)
249+
self._async_publish_registry.merge(router._async_publish_registry)
250+
self._subscribe_registry.merge(router._subscribe_registry)
251+
252+
def format_error_response(self, error=None) -> str:
253+
if isinstance(error, Exception):
254+
return f"{error.__class__.__name__} - {str(error)}"
255+
return "An unknown error occurred"

aws_lambda_powertools/event_handler/events_appsync/_registry.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,24 @@
22

33
import logging
44
import warnings
5-
from typing import Any, Callable, Literal
5+
from typing import TYPE_CHECKING
66

77
from aws_lambda_powertools.event_handler.events_appsync.functions import find_best_route, is_valid_path
88
from aws_lambda_powertools.warnings import PowertoolsUserWarning
99

10+
if TYPE_CHECKING:
11+
from collections.abc import Callable
12+
13+
from aws_lambda_powertools.event_handler.events_appsync.types import ResolverTypeDef
14+
15+
1016
logger = logging.getLogger(__name__)
1117

12-
KIND_EVENT = Literal["on_publish", "async_on_publish", "on_subscribe"]
1318

1419
class ResolverEventsRegistry:
20+
1521
def __init__(self, kind_resolver: str):
16-
self.resolvers: dict[str, dict[str, Any]] = {}
22+
self.resolvers: dict[str, ResolverTypeDef] = {}
1723
self.kind_resolver = kind_resolver
1824

1925
def register(
@@ -41,14 +47,14 @@ def register(
4147
warnings.warn(
4248
f"The path `{path}` registered for `{self.kind_resolver}` is not valid and will be skipped."
4349
f"A path should always have a namespace starting with '/'"
44-
"A path can have multiple namespaces, all separated by '/'."
50+
"A path can have multiple namespaces, all separated by '/'."
4551
"Wildcards are allowed only at the end of the path.",
4652
stacklevel=2,
4753
category=PowertoolsUserWarning,
4854
)
4955

50-
5156
def _register(func) -> Callable:
57+
5258
print(f"Adding resolver `{func.__name__}` for path `{path}` and kind_resolver `{self.kind_resolver}`")
5359
self.resolvers[f"{path}"] = {
5460
"func": func,
@@ -58,7 +64,7 @@ def _register(func) -> Callable:
5864

5965
return _register
6066

61-
def find_resolver(self, path: str) -> dict | None:
67+
def find_resolver(self, path: str) -> ResolverTypeDef | None:
6268
"""Find resolver based on type_name and field_name
6369
6470
Parameters

aws_lambda_powertools/event_handler/events_appsync/functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
PATH_REGEX = re.compile(r"^\/([^\/\*]+)(\/[^\/\*]+)*(\/\*)?$")
88

9+
910
def is_valid_path(path: str) -> bool:
1011
"""
1112
Checks if a given path is valid based on specific rules.
@@ -17,7 +18,7 @@ def is_valid_path(path: str) -> bool:
1718
1819
Returns:
1920
--------
20-
bool:
21+
bool:
2122
True if the path is valid, False otherwise
2223
2324
Examples:
@@ -38,6 +39,7 @@ def is_valid_path(path: str) -> bool:
3839
return True
3940
return bool(PATH_REGEX.fullmatch(path))
4041

42+
4143
def find_best_route(routes: dict[str, Any], path: str):
4244
"""
4345
Find the most specific matching route for a given path.
@@ -58,13 +60,14 @@ def find_best_route(routes: dict[str, Any], path: str):
5860
'/path/specific/*': {'func': callable, 'aggregate': bool}
5961
}
6062
}
61-
path: str
63+
path: str
6264
Actual path to match (e.g., '/default/v1/users')
6365
6466
Returns
6567
-------
6668
str: Most specific matching route or None if no match
6769
"""
70+
6871
@lru_cache(maxsize=1024)
6972
def pattern_to_regex(route):
7073
"""
@@ -80,7 +83,7 @@ def pattern_to_regex(route):
8083
8184
Returns
8285
-------
83-
Pattern:
86+
Pattern:
8487
Compiled regex pattern
8588
"""
8689
# Escape special regex chars but convert * to regex pattern
@@ -94,10 +97,7 @@ def pattern_to_regex(route):
9497
return re.compile(f"^{pattern}$")
9598

9699
# Find all matching routes
97-
matches = [
98-
route for route in routes.keys()
99-
if pattern_to_regex(route).match(path)
100-
]
100+
matches = [route for route in routes.keys() if pattern_to_regex(route).match(path)]
101101

102102
# Return the most specific route (longest length minus wildcards)
103103
# Examples of specificity:

aws_lambda_powertools/event_handler/events_appsync/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
class Router(BaseRouter):
1414

1515
context: dict
16-
current_event: AppSyncResolverEventsEvent | None = None
17-
lambda_context: LambdaContext | None = None
16+
current_event: AppSyncResolverEventsEvent
17+
lambda_context: LambdaContext
1818

1919
def __init__(self):
2020
self.context = {} # early init as customers might add context before event resolution

0 commit comments

Comments
 (0)