|
11 | 11 |
|
12 | 12 | from aws_lambda_decorators.classes import ExceptionHandler, Parameter, SSMParameter, ValidatedParameter |
13 | 13 | from aws_lambda_decorators.decorators import extract, extract_from_event, extract_from_context, handle_exceptions, \ |
14 | | - log, response_body_as_json, extract_from_ssm, validate, handle_all_exceptions, cors |
| 14 | + log, response_body_as_json, extract_from_ssm, validate, handle_all_exceptions, cors, push_ws_errors, \ |
| 15 | + push_ws_response |
| 16 | +from aws_lambda_decorators.utils import get_websocket_endpoint |
15 | 17 | from aws_lambda_decorators.validators import Mandatory, RegexValidator, SchemaValidator, Minimum, Maximum, MaxLength, \ |
16 | 18 | MinLength, Type, EnumValidator, NonEmpty, DateValidator, CurrencyValidator |
17 | 19 |
|
@@ -1833,6 +1835,140 @@ def handler(event, a=None): # noqa: pylint - unused-argument |
1833 | 1835 | "event", |
1834 | 1836 | "/a") |
1835 | 1837 |
|
| 1838 | + @patch("boto3.client") |
| 1839 | + def test_push_ws_errors_missing_parameter(self, mock_boto3_client): |
| 1840 | + get_websocket_endpoint.cache_clear() |
| 1841 | + |
| 1842 | + event = { |
| 1843 | + "requestContext": { |
| 1844 | + "connectionId": "test_connection_id" |
| 1845 | + }, |
| 1846 | + "body": json.dumps({ |
| 1847 | + "invalid_property": "invalid_value" |
| 1848 | + }) |
| 1849 | + } |
| 1850 | + |
| 1851 | + @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") |
| 1852 | + @extract_from_event(parameters=[ |
| 1853 | + Parameter(path="body[json]/valid_property", validators=[Mandatory]) |
| 1854 | + ]) |
| 1855 | + def lambda_func(event, context, valid_property=None): # noqa: pylint - unused-argument |
| 1856 | + return {"statusCode": HTTPStatus.OK} |
| 1857 | + |
| 1858 | + response = lambda_func(event, None) |
| 1859 | + |
| 1860 | + expected_data = { |
| 1861 | + "type": "error", |
| 1862 | + "statusCode": 400, |
| 1863 | + "message": [{ |
| 1864 | + "valid_property": ["Missing mandatory value"] |
| 1865 | + }] |
| 1866 | + } |
| 1867 | + |
| 1868 | + self.assertEqual(response["statusCode"], HTTPStatus.BAD_REQUEST) |
| 1869 | + |
| 1870 | + mock_boto3_client.assert_called_once_with( |
| 1871 | + "apigatewaymanagementapi", |
| 1872 | + endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod" |
| 1873 | + ) |
| 1874 | + |
| 1875 | + mock_boto3_client.return_value.post_to_connection.assert_called_once_with( |
| 1876 | + ConnectionId="test_connection_id", |
| 1877 | + Data=json.dumps(expected_data) |
| 1878 | + ) |
| 1879 | + |
| 1880 | + @patch("boto3.client") |
| 1881 | + def test_push_ws_errors_no_action_on_success(self, mock_boto3_client): |
| 1882 | + get_websocket_endpoint.cache_clear() |
| 1883 | + |
| 1884 | + event = { |
| 1885 | + "requestContext": { |
| 1886 | + "connectionId": "test_connection_id" |
| 1887 | + } |
| 1888 | + } |
| 1889 | + |
| 1890 | + @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") |
| 1891 | + def lambda_func(event, context): # noqa: pylint - unused-argument |
| 1892 | + return {"statusCode": HTTPStatus.OK} |
| 1893 | + |
| 1894 | + response = lambda_func(event, None) |
| 1895 | + |
| 1896 | + self.assertEqual(response["statusCode"], 200) |
| 1897 | + |
| 1898 | + mock_boto3_client.return_value.post_to_connection.assert_not_called() |
| 1899 | + |
| 1900 | + @patch("boto3.client") |
| 1901 | + def test_push_ws_errors_no_connection_id(self, mock_boto3_client): |
| 1902 | + get_websocket_endpoint.cache_clear() |
| 1903 | + |
| 1904 | + event = { |
| 1905 | + "body": { |
| 1906 | + "property": "value" |
| 1907 | + } |
| 1908 | + } |
| 1909 | + |
| 1910 | + @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") |
| 1911 | + def lambda_func(event, context): # noqa: pylint - unused-argument |
| 1912 | + return {"statusCode": HTTPStatus.BAD_REQUEST} |
| 1913 | + |
| 1914 | + response = lambda_func(event, None) |
| 1915 | + |
| 1916 | + self.assertEqual(response["statusCode"], 400) |
| 1917 | + |
| 1918 | + mock_boto3_client.return_value.post_to_connection.assert_not_called() |
| 1919 | + |
| 1920 | + @patch("boto3.client") |
| 1921 | + def test_push_ws_response(self, mock_boto3_client): |
| 1922 | + get_websocket_endpoint.cache_clear() |
| 1923 | + |
| 1924 | + event = { |
| 1925 | + "requestContext": { |
| 1926 | + "connectionId": "test_connection_id" |
| 1927 | + } |
| 1928 | + } |
| 1929 | + |
| 1930 | + @push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") |
| 1931 | + def lambda_func(event, context): # noqa: pylint - unused-argument |
| 1932 | + return { |
| 1933 | + "statusCode": HTTPStatus.OK, |
| 1934 | + "body": "Hello, world!" |
| 1935 | + } |
| 1936 | + |
| 1937 | + response = lambda_func(event, None) |
| 1938 | + |
| 1939 | + self.assertEqual(response["statusCode"], 200) |
| 1940 | + self.assertEqual(response["body"], "Hello, world!") |
| 1941 | + |
| 1942 | + mock_boto3_client.return_value.post_to_connection.assert_called_once_with( |
| 1943 | + ConnectionId="test_connection_id", |
| 1944 | + Data="{\"statusCode\": 200, \"body\": \"Hello, world!\"}" |
| 1945 | + ) |
| 1946 | + |
| 1947 | + @patch("boto3.client") |
| 1948 | + def test_push_ws_response_no_connection_id(self, mock_boto3_client): |
| 1949 | + get_websocket_endpoint.cache_clear() |
| 1950 | + |
| 1951 | + event = { |
| 1952 | + "body": "Hello, world!" |
| 1953 | + } |
| 1954 | + |
| 1955 | + @push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") |
| 1956 | + @extract_from_event(parameters=[ |
| 1957 | + Parameter(path="body") |
| 1958 | + ]) |
| 1959 | + def lambda_func(event, context, body=None): # noqa: pylint - unused-argument |
| 1960 | + return { |
| 1961 | + "statusCode": HTTPStatus.OK, |
| 1962 | + "body": body |
| 1963 | + } |
| 1964 | + |
| 1965 | + response = lambda_func(event, None) |
| 1966 | + |
| 1967 | + self.assertEqual(response["statusCode"], 200) |
| 1968 | + self.assertEqual(response["body"], "Hello, world!") |
| 1969 | + |
| 1970 | + mock_boto3_client.return_value.post_to_connection.assert_not_called() |
| 1971 | + |
1836 | 1972 |
|
1837 | 1973 | class IsolatedDecoderTests(unittest.TestCase): |
1838 | 1974 | # Tests have been named so they run in a specific order |
|
0 commit comments