Skip to content

Commit 1e3a327

Browse files
authored
Merge pull request #114 from gridsmartercities/websocket_error_reporting
Websocket response/error pushing
2 parents b683563 + 90b90af commit 1e3a327

5 files changed

Lines changed: 271 additions & 7 deletions

File tree

README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ The current list of AWS Lambda Python Decorators includes:
4444
* [__response_body_as_json__](#response_body_as_json): a decorator to transform a response dictionary body to a json string.
4545
* [__handle_all_exceptions__](#handle_all_exceptions): a decorator to handle all exceptions thrown by the lambda function.
4646
* [__cors__](#cors): a decorator to add cors headers to a lambda function.
47+
* [__push_ws_errors__](#push_ws_errors): a decorator to push unsuccessful responses back to the calling user via websockets with api gateway.
48+
* [__push_ws_responses__](#push_ws_response): a decorator to push all responses back to the calling user via websockets with api gateway.
4749

4850

4951
### [Validators](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/validators.py)
@@ -456,6 +458,45 @@ def cors_example():
456458
cors_example() # returns {'statusCode': 200, 'headers': {'access-control-allow-origin': '*', 'access-control-allow-methods': 'POST', 'access-control-allow-headers': 'Content-Type', 'access-control-max-age': 86400}}
457459
```
458460

461+
### push_ws_errors
462+
463+
This decorator pushes unsuccessful responses back to the calling client over websockets built on api gateway
464+
465+
This decorator requires the client is connected to the websocket api gateway instance, and will therefore have a connection id
466+
467+
Example:
468+
```py
469+
@push_ws_errors('https://api_id.execute_id.region.amazonaws.com/Prod')
470+
@handle_all_exceptions()
471+
def handler(event, context):
472+
return {
473+
'statusCode': 400,
474+
'body': {
475+
'message': 'Bad request'
476+
}
477+
}
478+
479+
# will push {'type': 'error', 'statusCode': 400, 'message': 'Bad request'} back to the client via websockets
480+
```
481+
482+
### push_ws_response
483+
484+
This decorator pushes all responses back to the calling client over websockets built on api gateway
485+
486+
This decorator requires the client is connected to the websocket api gateway instance, and will therefore have a connection id
487+
488+
Example:
489+
```py
490+
@push_ws_response('https://api_id.execute_id.region.amazonaws.com/Prod')
491+
def handler(event, context):
492+
return {
493+
'statusCode': 200,
494+
'body': 'Hello, world!'
495+
}
496+
497+
# will push {'statusCode': 200, 'body': 'Hello, world!'} back to the client via websockets
498+
```
499+
459500
## Writing your own validators
460501

461502
You can create your own validators by inheriting from the Validator class.

aws_lambda_decorators/decorators.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import json
88
from http import HTTPStatus
99
import boto3
10-
from aws_lambda_decorators.utils import full_name, all_func_args, find_key_case_insensitive, failure, get_logger
10+
from aws_lambda_decorators.utils import (full_name, all_func_args, find_key_case_insensitive, failure, get_logger,
11+
find_websocket_connection_id, get_websocket_endpoint)
1112

1213

1314
LOGGER = get_logger(__name__)
@@ -334,3 +335,50 @@ def update_header(headers, header_name, value, value_type):
334335
return failure(CORS_NON_DICT_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
335336
return wrapper
336337
return decorator
338+
339+
340+
def push_ws_errors(websocket_endpoint_url: str):
341+
def decorator(func):
342+
def wrapper(*args, **kwargs):
343+
connection_id = find_websocket_connection_id(args)
344+
345+
response = func(*args, **kwargs)
346+
success = response.get("statusCode", HTTPStatus.INTERNAL_SERVER_ERROR).value < 300
347+
348+
if connection_id and not success:
349+
websocket_endpoint = get_websocket_endpoint(websocket_endpoint_url)
350+
351+
ws_response = {
352+
"type": "error",
353+
"statusCode": response.get("statusCode", HTTPStatus.INTERNAL_SERVER_ERROR),
354+
"message": json.loads(response.get("body", "{}")).get("message")
355+
}
356+
357+
websocket_endpoint.post_to_connection(
358+
ConnectionId=connection_id,
359+
Data=json.dumps(ws_response)
360+
)
361+
362+
return response
363+
return wrapper
364+
return decorator
365+
366+
367+
def push_ws_response(websocket_endpoint_url: str):
368+
def decorator(func):
369+
def wrapper(*args, **kwargs):
370+
connection_id = find_websocket_connection_id(args)
371+
372+
response = func(*args, **kwargs)
373+
374+
if connection_id:
375+
websocket_endpoint = get_websocket_endpoint(websocket_endpoint_url)
376+
377+
websocket_endpoint.post_to_connection(
378+
ConnectionId=connection_id,
379+
Data=json.dumps(response)
380+
)
381+
382+
return response
383+
return wrapper
384+
return decorator

aws_lambda_decorators/utils.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Utility functions."""
2-
import logging
3-
import os
4-
import json
2+
from functools import lru_cache
53
from http import HTTPStatus
6-
import keyword
74
import inspect
5+
import json
6+
import keyword
7+
import logging
8+
import os
9+
10+
import boto3
811

912

1013
LOG_LEVEL = getattr(logging, os.getenv("LOG_LEVEL", "INFO"))
@@ -112,3 +115,39 @@ def failure(errors, status_code=HTTPStatus.BAD_REQUEST):
112115
"statusCode": status_code,
113116
"body": json.dumps({"message": errors})
114117
}
118+
119+
120+
def find_websocket_connection_id(args: list) -> str:
121+
"""
122+
Finds an API Gateway connection id from the event dictionary in the
123+
arguments of a lambda
124+
125+
Args:
126+
args (list): a list of arguments from a lambda (*args)
127+
128+
Returns:
129+
The connection id of a user as a string if found
130+
None if not
131+
"""
132+
for arg in args:
133+
if isinstance(arg, dict) and "requestContext" in arg:
134+
return arg["requestContext"].get("connectionId")
135+
return None
136+
137+
138+
@lru_cache()
139+
def get_websocket_endpoint(endpoint_url: str) -> "ApiGatewayManagementApi": # I can't find this in botocore.client
140+
"""
141+
Gets an instance of ApiGatewayManagementApi for sending messages
142+
through websockets
143+
144+
Args:
145+
endpoint_url (str): an api gateway connection url (ish)
146+
147+
Returns:
148+
The api gateway management client (cached)
149+
"""
150+
return boto3.client(
151+
"apigatewaymanagementapi",
152+
endpoint_url=endpoint_url
153+
)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
LONG_DESCRIPTION = open("README.md").read()
44

55
setup(name="aws-lambda-decorators",
6-
version="0.47",
6+
version="0.48",
77
description="A set of python decorators to simplify aws python lambda development",
88
long_description=LONG_DESCRIPTION,
99
long_description_content_type="text/markdown",

tests/test_decorators.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
from aws_lambda_decorators.classes import ExceptionHandler, Parameter, SSMParameter, ValidatedParameter
1313
from aws_lambda_decorators.decorators import extract, extract_from_event, extract_from_context, handle_exceptions, \
14-
log, response_body_as_json, extract_from_ssm, validate, handle_all_exceptions, cors
14+
log, response_body_as_json, extract_from_ssm, validate, handle_all_exceptions, cors, push_ws_errors, \
15+
push_ws_response
16+
from aws_lambda_decorators.utils import get_websocket_endpoint
1517
from aws_lambda_decorators.validators import Mandatory, RegexValidator, SchemaValidator, Minimum, Maximum, MaxLength, \
1618
MinLength, Type, EnumValidator, NonEmpty, DateValidator, CurrencyValidator
1719

@@ -1833,6 +1835,140 @@ def handler(event, a=None): # noqa: pylint - unused-argument
18331835
"event",
18341836
"/a")
18351837

1838+
@patch("boto3.client")
1839+
def test_push_ws_errors_missing_parameter(self, mock_boto3_client):
1840+
get_websocket_endpoint.cache_clear()
1841+
1842+
event = {
1843+
"requestContext": {
1844+
"connectionId": "test_connection_id"
1845+
},
1846+
"body": json.dumps({
1847+
"invalid_property": "invalid_value"
1848+
})
1849+
}
1850+
1851+
@push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
1852+
@extract_from_event(parameters=[
1853+
Parameter(path="body[json]/valid_property", validators=[Mandatory])
1854+
])
1855+
def lambda_func(event, context, valid_property=None): # noqa: pylint - unused-argument
1856+
return {"statusCode": HTTPStatus.OK}
1857+
1858+
response = lambda_func(event, None)
1859+
1860+
expected_data = {
1861+
"type": "error",
1862+
"statusCode": 400,
1863+
"message": [{
1864+
"valid_property": ["Missing mandatory value"]
1865+
}]
1866+
}
1867+
1868+
self.assertEqual(response["statusCode"], HTTPStatus.BAD_REQUEST)
1869+
1870+
mock_boto3_client.assert_called_once_with(
1871+
"apigatewaymanagementapi",
1872+
endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod"
1873+
)
1874+
1875+
mock_boto3_client.return_value.post_to_connection.assert_called_once_with(
1876+
ConnectionId="test_connection_id",
1877+
Data=json.dumps(expected_data)
1878+
)
1879+
1880+
@patch("boto3.client")
1881+
def test_push_ws_errors_no_action_on_success(self, mock_boto3_client):
1882+
get_websocket_endpoint.cache_clear()
1883+
1884+
event = {
1885+
"requestContext": {
1886+
"connectionId": "test_connection_id"
1887+
}
1888+
}
1889+
1890+
@push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
1891+
def lambda_func(event, context): # noqa: pylint - unused-argument
1892+
return {"statusCode": HTTPStatus.OK}
1893+
1894+
response = lambda_func(event, None)
1895+
1896+
self.assertEqual(response["statusCode"], 200)
1897+
1898+
mock_boto3_client.return_value.post_to_connection.assert_not_called()
1899+
1900+
@patch("boto3.client")
1901+
def test_push_ws_errors_no_connection_id(self, mock_boto3_client):
1902+
get_websocket_endpoint.cache_clear()
1903+
1904+
event = {
1905+
"body": {
1906+
"property": "value"
1907+
}
1908+
}
1909+
1910+
@push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
1911+
def lambda_func(event, context): # noqa: pylint - unused-argument
1912+
return {"statusCode": HTTPStatus.BAD_REQUEST}
1913+
1914+
response = lambda_func(event, None)
1915+
1916+
self.assertEqual(response["statusCode"], 400)
1917+
1918+
mock_boto3_client.return_value.post_to_connection.assert_not_called()
1919+
1920+
@patch("boto3.client")
1921+
def test_push_ws_response(self, mock_boto3_client):
1922+
get_websocket_endpoint.cache_clear()
1923+
1924+
event = {
1925+
"requestContext": {
1926+
"connectionId": "test_connection_id"
1927+
}
1928+
}
1929+
1930+
@push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
1931+
def lambda_func(event, context): # noqa: pylint - unused-argument
1932+
return {
1933+
"statusCode": HTTPStatus.OK,
1934+
"body": "Hello, world!"
1935+
}
1936+
1937+
response = lambda_func(event, None)
1938+
1939+
self.assertEqual(response["statusCode"], 200)
1940+
self.assertEqual(response["body"], "Hello, world!")
1941+
1942+
mock_boto3_client.return_value.post_to_connection.assert_called_once_with(
1943+
ConnectionId="test_connection_id",
1944+
Data="{\"statusCode\": 200, \"body\": \"Hello, world!\"}"
1945+
)
1946+
1947+
@patch("boto3.client")
1948+
def test_push_ws_response_no_connection_id(self, mock_boto3_client):
1949+
get_websocket_endpoint.cache_clear()
1950+
1951+
event = {
1952+
"body": "Hello, world!"
1953+
}
1954+
1955+
@push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
1956+
@extract_from_event(parameters=[
1957+
Parameter(path="body")
1958+
])
1959+
def lambda_func(event, context, body=None): # noqa: pylint - unused-argument
1960+
return {
1961+
"statusCode": HTTPStatus.OK,
1962+
"body": body
1963+
}
1964+
1965+
response = lambda_func(event, None)
1966+
1967+
self.assertEqual(response["statusCode"], 200)
1968+
self.assertEqual(response["body"], "Hello, world!")
1969+
1970+
mock_boto3_client.return_value.post_to_connection.assert_not_called()
1971+
18361972

18371973
class IsolatedDecoderTests(unittest.TestCase):
18381974
# Tests have been named so they run in a specific order

0 commit comments

Comments
 (0)