|
4 | 4 | import time |
5 | 5 | import unittest |
6 | 6 | try: |
7 | | - from unittest.mock import patch, ANY, mock_open |
| 7 | + from unittest.mock import patch, ANY, mock_open, Mock |
8 | 8 | except: |
9 | | - from mock import patch, ANY, mock_open |
| 9 | + from mock import patch, ANY, mock_open, Mock |
10 | 10 | import requests |
11 | 11 |
|
12 | 12 | from tests.http_client import MinimalResponse |
13 | 13 | from msal import ( |
14 | 14 | SystemAssignedManagedIdentity, UserAssignedManagedIdentity, |
15 | 15 | ManagedIdentityClient, |
16 | 16 | ManagedIdentityError, |
| 17 | + ArcPlatformNotSupportedError, |
17 | 18 | ) |
| 19 | +from msal.managed_identity import _supported_arc_platforms_and_their_prefixes |
18 | 20 |
|
19 | 21 |
|
20 | 22 | class ManagedIdentityTestCase(unittest.TestCase): |
@@ -194,29 +196,41 @@ def test_sf_error_should_be_normalized(self): |
194 | 196 | new=mock_open(read_data="secret"), # `new` requires no extra argument on the decorated function. |
195 | 197 | # https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch |
196 | 198 | ) |
| 199 | +@patch("os.stat", return_value=Mock(st_size=4096)) |
197 | 200 | class ArcTestCase(ClientTestCase): |
198 | 201 | challenge = MinimalResponse(status_code=401, text="", headers={ |
199 | 202 | "WWW-Authenticate": "Basic realm=/tmp/foo", |
200 | 203 | }) |
201 | 204 |
|
202 | | - def test_happy_path(self): |
| 205 | + def test_happy_path(self, mocked_stat): |
203 | 206 | with patch.object(self.app._http_client, "get", side_effect=[ |
204 | 207 | self.challenge, |
205 | 208 | MinimalResponse( |
206 | 209 | status_code=200, |
207 | 210 | text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', |
208 | 211 | ), |
209 | 212 | ]) as mocked_method: |
210 | | - super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) |
211 | | - |
212 | | - def test_arc_error_should_be_normalized(self): |
| 213 | + try: |
| 214 | + super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) |
| 215 | + mocked_stat.assert_called_with(os.path.join( |
| 216 | + _supported_arc_platforms_and_their_prefixes[sys.platform], |
| 217 | + "foo.key")) |
| 218 | + except ArcPlatformNotSupportedError: |
| 219 | + if sys.platform in _supported_arc_platforms_and_their_prefixes: |
| 220 | + self.fail("Should not raise ArcPlatformNotSupportedError") |
| 221 | + |
| 222 | + def test_arc_error_should_be_normalized(self, mocked_stat): |
213 | 223 | with patch.object(self.app._http_client, "get", side_effect=[ |
214 | 224 | self.challenge, |
215 | 225 | MinimalResponse(status_code=400, text="undefined"), |
216 | 226 | ]) as mocked_method: |
217 | | - self.assertEqual({ |
218 | | - "error": "invalid_request", |
219 | | - "error_description": "undefined", |
220 | | - }, self.app.acquire_token_for_client(resource="R")) |
221 | | - self.assertEqual({}, self.app._token_cache._cache) |
| 227 | + try: |
| 228 | + self.assertEqual({ |
| 229 | + "error": "invalid_request", |
| 230 | + "error_description": "undefined", |
| 231 | + }, self.app.acquire_token_for_client(resource="R")) |
| 232 | + self.assertEqual({}, self.app._token_cache._cache) |
| 233 | + except ArcPlatformNotSupportedError: |
| 234 | + if sys.platform in _supported_arc_platforms_and_their_prefixes: |
| 235 | + self.fail("Should not raise ArcPlatformNotSupportedError") |
222 | 236 |
|
0 commit comments