|
1 | 1 | import unittest |
| 2 | +import os |
2 | 3 | from unittest.mock import patch, MagicMock, Mock, create_autospec |
3 | 4 | from mns_service import MnsService |
4 | 5 | from authentication import AppRestrictedAuth |
5 | | -from models.errors import ServerError, UnhandledResponseError |
| 6 | +from models.errors import ServerError, UnhandledResponseError, TokenValidationError |
6 | 7 |
|
7 | 8 |
|
| 9 | +SQS_ARN = "arn:aws:sqs:eu-west-2:123456789012:my-queue" |
| 10 | + |
| 11 | +@patch("mns_service.SQS_ARN", SQS_ARN) |
8 | 12 | class TestMnsService(unittest.TestCase): |
9 | 13 | def setUp(self): |
10 | 14 | # Common mock setup |
11 | 15 | self.authenticator = create_autospec(AppRestrictedAuth) |
12 | 16 | self.authenticator.get_access_token.return_value = "mocked_token" |
13 | 17 | self.mock_secret_manager = Mock() |
14 | 18 | self.mock_cache = Mock() |
| 19 | + self.sqs = SQS_ARN |
15 | 20 |
|
16 | 21 | @patch("mns_service.requests.post") |
17 | 22 | @patch("mns_service.requests.get") |
@@ -67,6 +72,76 @@ def test_unhandled_error(self, mock_post): |
67 | 72 |
|
68 | 73 | self.assertIn("Internal Server Error", str(context.exception)) |
69 | 74 |
|
| 75 | + @patch.dict(os.environ, {"SQS_ARN": "arn:aws:sqs:eu-west-2:123456789012:my-queue"}) |
| 76 | + @patch("mns_service.requests.get") |
| 77 | + def test_get_subscription_success(self, mock_get): |
| 78 | + """Should return the resource dict when a matching subscription exists.""" |
| 79 | + # Arrange a bundle with a matching entry |
| 80 | + mock_response = MagicMock() |
| 81 | + mock_response.status_code = 200 |
| 82 | + mock_response.json.return_value = { |
| 83 | + "entry": [ |
| 84 | + { |
| 85 | + "channel": {"endpoint": SQS_ARN}, |
| 86 | + "id": "123" |
| 87 | + } |
| 88 | + ] |
| 89 | + } |
| 90 | + mock_get.return_value = mock_response |
| 91 | + |
| 92 | + service = MnsService(self.authenticator) |
| 93 | + result2 = service.get_subscription() |
| 94 | + self.assertIsNotNone(result2) |
| 95 | + self.assertEqual(result2["channel"]["endpoint"], SQS_ARN) |
| 96 | + |
| 97 | + @patch("mns_service.requests.get") |
| 98 | + def test_get_subscription_no_match(self, mock_get): |
| 99 | + """Should return None when no subscription matches.""" |
| 100 | + mock_response = MagicMock() |
| 101 | + mock_response.status_code = 200 |
| 102 | + mock_response.json.return_value = {"entry": []} # No matches |
| 103 | + mock_get.return_value = mock_response |
| 104 | + |
| 105 | + service = MnsService(self.authenticator) |
| 106 | + result = service.get_subscription() |
| 107 | + self.assertIsNone(result) |
| 108 | + |
| 109 | + @patch("mns_service.requests.get") |
| 110 | + def test_get_subscription_401(self, mock_get): |
| 111 | + """Should raise TokenValidationError for 401.""" |
| 112 | + mock_response = MagicMock() |
| 113 | + mock_response.status_code = 401 |
| 114 | + mock_response.json.return_value = {"fault": {"faultstring": "Invalid Access Token"}} |
| 115 | + mock_get.return_value = mock_response |
| 116 | + |
| 117 | + service = MnsService(self.authenticator) |
| 118 | + with self.assertRaises(TokenValidationError): |
| 119 | + service.get_subscription() |
| 120 | + |
| 121 | + # Similarly, you can add tests for 400, 403, 500, etc. |
| 122 | + |
| 123 | + @patch("mns_service.requests.post") |
| 124 | + @patch("mns_service.requests.get") |
| 125 | + def test_check_subscription_creates_if_not_found(self, mock_get, mock_post): |
| 126 | + """If GET finds nothing, POST is called and returned.""" |
| 127 | + # Arrange GET returns no match |
| 128 | + mock_get_response = MagicMock() |
| 129 | + mock_get_response.status_code = 200 |
| 130 | + mock_get_response.json.return_value = {"entry": []} |
| 131 | + mock_get.return_value = mock_get_response |
| 132 | + |
| 133 | + # Arrange POST returns a new subscription |
| 134 | + mock_post_response = MagicMock() |
| 135 | + mock_post_response.status_code = 201 |
| 136 | + mock_post_response.json.return_value = {"subscriptionId": "abc123"} |
| 137 | + mock_post.return_value = mock_post_response |
| 138 | + |
| 139 | + service = MnsService(self.authenticator) |
| 140 | + result = service.check_subscription() |
| 141 | + self.assertEqual(result, {"subscriptionId": "abc123"}) |
| 142 | + mock_get.assert_called_once() |
| 143 | + mock_post.assert_called_once() |
| 144 | + |
70 | 145 |
|
71 | 146 | if __name__ == "__main__": |
72 | 147 | unittest.main() |
0 commit comments