From 59a97b060837cfb31ef03cd3aee9044b30f78a08 Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Thu, 23 Mar 2023 18:52:11 -0600 Subject: [PATCH 01/10] Initial Commit For Adding UserCredentialsStrategy --- changes/unknown.feature.md | 2 + hikari/impl/rest.py | 126 +++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 changes/unknown.feature.md diff --git a/changes/unknown.feature.md b/changes/unknown.feature.md new file mode 100644 index 0000000000..3ab1324678 --- /dev/null +++ b/changes/unknown.feature.md @@ -0,0 +1,2 @@ +Add `UserCredentialsStrategy` to `hikari\impl\rest.py` similar to `ClientCredentialsStrategy` to allow simple OAuth2 +token generation and refreshing diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 5e1b42d31e..d467fa0d0d 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -214,6 +214,132 @@ def invalidate(self, token: typing.Optional[str]) -> None: self._token = None +class UserCredentialsStrategy(rest_api.TokenStrategy): + """Strategy class for handling client credential OAuth2 authorization. + + Parameters + ---------- + client_id : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]] + Object or ID of the application this client credentials strategy should + authorize as. + client_secret : str + Client secret to use when authorizing. + auth_code : str + Auth code given from Discord when user authorizes + redirect_uri: str + The redirect uri that was included in the authorization request + + Other Parameters + ---------------- + scopes : typing.Sequence[str] + The scopes to authorize for. + """ + + __slots__: typing.Sequence[str] = ( + "_client_id", + "_client_secret", + "_exception", + "_expire_at", + "_lock", + "_scopes", + "_token", + "_code", + "_redirect_uri", + "_refresh_token", + ) + + def __init__( + self, + client_id: snowflakes.SnowflakeishOr[guilds.PartialApplication], + client_secret: str, + auth_code: str, + redirect_uri: str, + *, + scopes: typing.Sequence[typing.Union[applications.OAuth2Scope, str]] = (applications.OAuth2Scope.IDENTIFY,), + ) -> None: + self._client_id = snowflakes.Snowflake(client_id) + self._client_secret = client_secret + self._exception: typing.Optional[errors.ClientHTTPResponseError] = None + self._expire_at = 0.0 + self._lock = asyncio.Lock() + self._scopes = scopes + self._token: typing.Optional[str] = None + self._refresh_token = None + self._code = auth_code + self._redirect_uri = redirect_uri + + @property + def client_id(self) -> snowflakes.Snowflake: + """ID of the application this token strategy authenticates with.""" + return self._client_id + + @property + def _is_expired(self) -> bool: + return time.monotonic() >= self._expire_at + + @property + def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]]: + """Sequence of scopes this token strategy authenticates for.""" + return self._scopes + + @property + def token_type(self) -> applications.TokenType: + return applications.TokenType.BEARER + + async def acquire(self, client: rest_api.RESTClient | RESTApp) -> str: + if not self._code: + raise "Token has been invalidated. Unable to get current or new token" + + if self._token and not self._is_expired: + return self._token + + async with self._lock: + if self._token and not self._is_expired: + return self._token + + if self._exception: + # If we don't copy the exception then python keeps adding onto the stack each time it's raised. + raise copy.copy(self._exception) from None + + try: + if isinstance(client, RESTApp): + client = client.acquire() + client.start() + + if not self._token: + response = await client.authorize_access_token( + client=self._client_id, + client_secret=self._client_secret, + code=self._code, + redirect_uri=self._redirect_uri, + ) + else: + response = await client.refresh_access_token( + client=self._client_id, client_secret=self._client_secret, refresh_token=self._refresh_token + ) + + except errors.ClientHTTPResponseError as exc: + if not isinstance(exc, errors.RateLimitTooLongError): + # If we don't copy the exception then python keeps adding onto the stack each time it's raised. + self._exception = copy.copy(exc) + await client.close() + raise + + # Expires in is lowered a bit in-order to lower the chance of a dead token being used. + self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99) + self._token = f"{response.access_token}" + self._refresh_token = response.refresh_token + await client.close() + return self._token + + def invalidate(self, token: typing.Optional[str] = None) -> None: + if not token or token == self._token: + self._expire_at = 0.0 + self._token = None + self._refresh_token = None + self._code = None + + class _RESTProvider(traits.RESTAware): __slots__: typing.Sequence[str] = ("_entity_factory", "_executor", "_rest") From 69ad04d5fbd6a2c389edf37fdff8cc354837bdeb Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Thu, 23 Mar 2023 19:01:38 -0600 Subject: [PATCH 02/10] Rename 'unknown.feature.md' to '1558.feature.md' --- changes/{unknown.feature.md => 1558.feature.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename changes/{unknown.feature.md => 1558.feature.md} (100%) diff --git a/changes/unknown.feature.md b/changes/1558.feature.md similarity index 100% rename from changes/unknown.feature.md rename to changes/1558.feature.md From d94d187fc703a135cbedd784731f3d5879c39d05 Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Fri, 24 Mar 2023 09:26:13 -0600 Subject: [PATCH 03/10] Update changes/1558.feature.md Co-authored-by: davfsa Signed-off-by: theonlydoublee --- changes/1558.feature.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/changes/1558.feature.md b/changes/1558.feature.md index 3ab1324678..d80eaf37c3 100644 --- a/changes/1558.feature.md +++ b/changes/1558.feature.md @@ -1,2 +1 @@ -Add `UserCredentialsStrategy` to `hikari\impl\rest.py` similar to `ClientCredentialsStrategy` to allow simple OAuth2 -token generation and refreshing +Add `OAuthCredentialsStrategy` to `hikari.impl.rest` for OAuth2 flow tokens. From e8e89ca8cbedd55eb7fa10fe6b063713c20a6c3e Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Fri, 24 Mar 2023 15:22:25 -0600 Subject: [PATCH 04/10] Add tests and Add Davfsa's Changes --- hikari/impl/rest.py | 37 +++---- tests/hikari/impl/test_rest.py | 193 +++++++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+), 22 deletions(-) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index d467fa0d0d..40ac82a051 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -214,12 +214,12 @@ def invalidate(self, token: typing.Optional[str]) -> None: self._token = None -class UserCredentialsStrategy(rest_api.TokenStrategy): - """Strategy class for handling client credential OAuth2 authorization. +class OAuthCredentialsStrategy(rest_api.TokenStrategy): + """Strategy class for handling OAuth2 authorization. Parameters ---------- - client_id : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]] + client : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]] Object or ID of the application this client credentials strategy should authorize as. client_secret : str @@ -243,21 +243,21 @@ class UserCredentialsStrategy(rest_api.TokenStrategy): "_lock", "_scopes", "_token", - "_code", + "_auth_code", "_redirect_uri", "_refresh_token", ) def __init__( self, - client_id: snowflakes.SnowflakeishOr[guilds.PartialApplication], + client: snowflakes.SnowflakeishOr[guilds.PartialApplication], client_secret: str, auth_code: str, redirect_uri: str, *, scopes: typing.Sequence[typing.Union[applications.OAuth2Scope, str]] = (applications.OAuth2Scope.IDENTIFY,), ) -> None: - self._client_id = snowflakes.Snowflake(client_id) + self._client_id = snowflakes.Snowflake(client) self._client_secret = client_secret self._exception: typing.Optional[errors.ClientHTTPResponseError] = None self._expire_at = 0.0 @@ -265,7 +265,7 @@ def __init__( self._scopes = scopes self._token: typing.Optional[str] = None self._refresh_token = None - self._code = auth_code + self._auth_code = auth_code self._redirect_uri = redirect_uri @property @@ -273,7 +273,6 @@ def client_id(self) -> snowflakes.Snowflake: """ID of the application this token strategy authenticates with.""" return self._client_id - @property def _is_expired(self) -> bool: return time.monotonic() >= self._expire_at @@ -286,15 +285,15 @@ def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]] def token_type(self) -> applications.TokenType: return applications.TokenType.BEARER - async def acquire(self, client: rest_api.RESTClient | RESTApp) -> str: - if not self._code: - raise "Token has been invalidated. Unable to get current or new token" + async def acquire(self, client: rest_api.RESTClient) -> str: + if not self._auth_code: + raise RuntimeError("Token has been invalidated. Unable to get current or new token") - if self._token and not self._is_expired: + if self._token and not self._is_expired(): return self._token async with self._lock: - if self._token and not self._is_expired: + if self._token and not self._is_expired(): return self._token if self._exception: @@ -302,15 +301,11 @@ async def acquire(self, client: rest_api.RESTClient | RESTApp) -> str: raise copy.copy(self._exception) from None try: - if isinstance(client, RESTApp): - client = client.acquire() - client.start() - if not self._token: response = await client.authorize_access_token( client=self._client_id, client_secret=self._client_secret, - code=self._code, + code=self._auth_code, redirect_uri=self._redirect_uri, ) else: @@ -322,14 +317,12 @@ async def acquire(self, client: rest_api.RESTClient | RESTApp) -> str: if not isinstance(exc, errors.RateLimitTooLongError): # If we don't copy the exception then python keeps adding onto the stack each time it's raised. self._exception = copy.copy(exc) - await client.close() raise # Expires in is lowered a bit in-order to lower the chance of a dead token being used. self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99) - self._token = f"{response.access_token}" + self._token = str(response.access_token) self._refresh_token = response.refresh_token - await client.close() return self._token def invalidate(self, token: typing.Optional[str] = None) -> None: @@ -337,7 +330,7 @@ def invalidate(self, token: typing.Optional[str] = None) -> None: self._expire_at = 0.0 self._token = None self._refresh_token = None - self._code = None + self._auth_code = None class _RESTProvider(traits.RESTAware): diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 1a062e5d67..64df21af24 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -292,6 +292,199 @@ def test_invalidate_when_token_is_stored_token(self): assert strategy._expire_at == 0.0 assert strategy._token is None +############################# +# OAuthCredentialsStrategy # +############################# + + +class TestOAuthCredentialsStrategy: + @pytest.fixture() + def mock_token(self): + return mock.Mock( + applications.PartialOAuth2Token, + expires_in=datetime.timedelta(weeks=1), + token_type=applications.TokenType.BEARER, + access_token="mockmock.tokentoken.mocktoken", + refresh_token="refresh.mock.token" + ) + + def test_client_id_property(self): + mock_client = hikari_test_helpers.mock_class_namespace(applications.Application, id=41551, init_=False)() + strategy = rest.OAuthCredentialsStrategy(client=mock_client, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + + assert strategy.client_id == 41551 + + def test_scopes_property(self): + token = rest.OAuthCredentialsStrategy(client=987654321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", scopes=("identify",)) + + assert token.scopes == ("identify",) + + def test_token_type_property(self): + token = rest.OAuthCredentialsStrategy(client=987654321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + assert token.token_type is applications.TokenType.BEARER + + @pytest.mark.asyncio() + async def test_acquire_on_new_instance(self, mock_token): + mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) + + result = await rest.OAuthCredentialsStrategy(client=987654321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", scopes=("identify",)).acquire(mock_rest) + + assert result == "mockmock.tokentoken.mocktoken" + + mock_rest.authorize_access_token.assert_awaited_once_with( + client=987654321, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + + @pytest.mark.asyncio() + async def test_acquire_handles_out_of_date_token(self, mock_token): + mock_old_token = mock.AsyncMock( + applications.PartialOAuth2Token, + expires_in=datetime.timedelta(weeks=1), + token_type=applications.TokenType.BEARER, + access_token="old.mock.token", + refresh_token="refresh.mock.token" + ) + mock_rest = mock.AsyncMock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + token = await strategy.acquire( + mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) + ) + + with mock.patch.object(time, "monotonic", return_value=99999999999): + new_token = await strategy.acquire(mock_rest) + + mock_rest.refresh_access_token.assert_awaited_once_with( + client=123456789, client_secret="123123123", refresh_token="refresh.mock.token" + ) + assert new_token != token + assert new_token == "mockmock.tokentoken.mocktoken" + + @pytest.mark.asyncio() + async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token): + lock = asyncio.Lock() + mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(side_effect=[mock_token])) + + with mock.patch.object(asyncio, "Lock", return_value=lock): + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + + async with lock: + tokens_gather = asyncio.gather( + strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) + ) + + results = await tokens_gather + + mock_rest.authorize_access_token.assert_awaited_once_with( + client=123456789, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + assert results == [ + "mockmock.tokentoken.mocktoken", + "mockmock.tokentoken.mocktoken", + "mockmock.tokentoken.mocktoken", + ] + + @pytest.mark.asyncio() + async def test_acquire_after_invalidation(self, mock_token): + mock_old_token = mock.AsyncMock( + applications.PartialOAuth2Token, + expires_in=datetime.timedelta(weeks=1), + token_type=applications.TokenType.BEARER, + access_token="okokok.fofdsasdasdofo.ddd", + refresh_token="refresh.mock.token" + ) + mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", scopes=("identify",)) + token = await strategy.acquire( + mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) + ) + + strategy.invalidate(token) + with pytest.raises(RuntimeError): + await strategy.acquire(mock_rest) + + @pytest.mark.asyncio() + async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self): + class MockLock: + def __init__(self, strat): + self._strategy = strat + + async def __aenter__(self): + self._strategy._token = "cab.cab.cab" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return + + mock_rest = mock.AsyncMock() + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy._lock = MockLock(strategy) + strategy._token = None + strategy._expire_at = time.monotonic() + 500 + + result = await strategy.acquire(mock_rest) + + assert result == "cab.cab.cab" + + mock_rest.authorize_access_token.assert_not_called() + + @pytest.mark.asyncio() + async def test_acquire_caches_client_http_response_error(self): + mock_rest = mock.AsyncMock() + error = errors.ClientHTTPResponseError( + url="okokok", status=42, headers={}, raw_body=b"ok", message="OK", code=34123 + ) + mock_rest.authorize_access_token.side_effect = error + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + + with pytest.raises(errors.ClientHTTPResponseError): + await strategy.acquire(mock_rest) + + with pytest.raises(errors.ClientHTTPResponseError): + await strategy.acquire(mock_rest) + + mock_rest.authorize_access_token.assert_awaited_once_with( + client=123456789, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + + def test_invalidate_when_token_is_not_stored_token(self): + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy._expire_at = 10.0 + strategy._token = "token" + + strategy.invalidate("tokena") + + assert strategy._expire_at == 10.0 + assert strategy._token == "token" + + def test_invalidate_when_no_token_specified(self): + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy._expire_at = 10.0 + strategy._token = "token" + + strategy.invalidate(None) + + assert strategy._expire_at == 0.0 + assert strategy._token is None + + def test_invalidate_when_token_is_stored_token(self): + strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy._expire_at = 10.0 + strategy._token = "token" + + strategy.invalidate("token") + + assert strategy._expire_at == 0.0 + assert strategy._token is None + ########### # RESTApp # From 07222474b84aabc0f441cf65d753211baeca5a56 Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Fri, 24 Mar 2023 15:35:29 -0600 Subject: [PATCH 05/10] Ran black on test_rest.py on changes --- tests/hikari/impl/test_rest.py | 104 ++++++++++++++++++++++++++------- 1 file changed, 84 insertions(+), 20 deletions(-) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 64df21af24..19633580a2 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -292,6 +292,7 @@ def test_invalidate_when_token_is_stored_token(self): assert strategy._expire_at == 0.0 assert strategy._token is None + ############################# # OAuthCredentialsStrategy # ############################# @@ -305,29 +306,51 @@ def mock_token(self): expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="mockmock.tokentoken.mocktoken", - refresh_token="refresh.mock.token" + refresh_token="refresh.mock.token", ) def test_client_id_property(self): mock_client = hikari_test_helpers.mock_class_namespace(applications.Application, id=41551, init_=False)() - strategy = rest.OAuthCredentialsStrategy(client=mock_client, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=mock_client, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) assert strategy.client_id == 41551 def test_scopes_property(self): - token = rest.OAuthCredentialsStrategy(client=987654321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", scopes=("identify",)) + token = rest.OAuthCredentialsStrategy( + client=987654321, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + scopes=("identify",), + ) assert token.scopes == ("identify",) def test_token_type_property(self): - token = rest.OAuthCredentialsStrategy(client=987654321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + token = rest.OAuthCredentialsStrategy( + client=987654321, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) assert token.token_type is applications.TokenType.BEARER @pytest.mark.asyncio() async def test_acquire_on_new_instance(self, mock_token): mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - result = await rest.OAuthCredentialsStrategy(client=987654321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", scopes=("identify",)).acquire(mock_rest) + result = await rest.OAuthCredentialsStrategy( + client=987654321, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + scopes=("identify",), + ).acquire(mock_rest) assert result == "mockmock.tokentoken.mocktoken" @@ -345,10 +368,15 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="old.mock.token", - refresh_token="refresh.mock.token" + refresh_token="refresh.mock.token", ) mock_rest = mock.AsyncMock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) token = await strategy.acquire( mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) ) @@ -368,7 +396,12 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(side_effect=[mock_token])) with mock.patch.object(asyncio, "Lock", return_value=lock): - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) async with lock: tokens_gather = asyncio.gather( @@ -396,10 +429,16 @@ async def test_acquire_after_invalidation(self, mock_token): expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="okokok.fofdsasdasdofo.ddd", - refresh_token="refresh.mock.token" + refresh_token="refresh.mock.token", ) mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", scopes=("identify",)) + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + scopes=("identify",), + ) token = await strategy.acquire( mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) ) @@ -422,7 +461,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): return mock_rest = mock.AsyncMock() - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) strategy._lock = MockLock(strategy) strategy._token = None strategy._expire_at = time.monotonic() + 500 @@ -440,7 +484,12 @@ async def test_acquire_caches_client_http_response_error(self): url="okokok", status=42, headers={}, raw_body=b"ok", message="OK", code=34123 ) mock_rest.authorize_access_token.side_effect = error - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) with pytest.raises(errors.ClientHTTPResponseError): await strategy.acquire(mock_rest) @@ -449,14 +498,19 @@ async def test_acquire_caches_client_http_response_error(self): await strategy.acquire(mock_rest) mock_rest.authorize_access_token.assert_awaited_once_with( - client=123456789, - client_secret="123123123", - code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) + client=123456789, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) def test_invalidate_when_token_is_not_stored_token(self): - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) strategy._expire_at = 10.0 strategy._token = "token" @@ -466,7 +520,12 @@ def test_invalidate_when_token_is_not_stored_token(self): assert strategy._token == "token" def test_invalidate_when_no_token_specified(self): - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) strategy._expire_at = 10.0 strategy._token = "token" @@ -476,7 +535,12 @@ def test_invalidate_when_no_token_specified(self): assert strategy._token is None def test_invalidate_when_token_is_stored_token(self): - strategy = rest.OAuthCredentialsStrategy(client=123456789, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord") + strategy = rest.OAuthCredentialsStrategy( + client=123456789, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) strategy._expire_at = 10.0 strategy._token = "token" From 68e83fb4e43f238e79da83187ddeff48cfa010fb Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Fri, 24 Mar 2023 15:50:44 -0600 Subject: [PATCH 06/10] Happy Mypy in nox --- tests/hikari/impl/test_rest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 19633580a2..b86cf6079d 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -306,7 +306,7 @@ def mock_token(self): expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="mockmock.tokentoken.mocktoken", - refresh_token="refresh.mock.token", + refresh_token=7654, ) def test_client_id_property(self): @@ -368,7 +368,7 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="old.mock.token", - refresh_token="refresh.mock.token", + refresh_token=7654, ) mock_rest = mock.AsyncMock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) strategy = rest.OAuthCredentialsStrategy( @@ -385,7 +385,7 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): new_token = await strategy.acquire(mock_rest) mock_rest.refresh_access_token.assert_awaited_once_with( - client=123456789, client_secret="123123123", refresh_token="refresh.mock.token" + client=123456789, client_secret="123123123", refresh_token=7654 ) assert new_token != token assert new_token == "mockmock.tokentoken.mocktoken" @@ -429,7 +429,7 @@ async def test_acquire_after_invalidation(self, mock_token): expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="okokok.fofdsasdasdofo.ddd", - refresh_token="refresh.mock.token", + refresh_token=7654, ) mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) strategy = rest.OAuthCredentialsStrategy( From c476d12668e44986af2e5f2c4eb907f273d2e9d9 Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Fri, 24 Mar 2023 17:23:17 -0600 Subject: [PATCH 07/10] Mypy Typing Fixed --- hikari/impl/rest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 40ac82a051..9be7b67eaf 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -264,8 +264,8 @@ def __init__( self._lock = asyncio.Lock() self._scopes = scopes self._token: typing.Optional[str] = None - self._refresh_token = None - self._auth_code = auth_code + self._refresh_token: str = "" + self._auth_code: typing.Optional[str] = auth_code self._redirect_uri = redirect_uri @property @@ -321,15 +321,15 @@ async def acquire(self, client: rest_api.RESTClient) -> str: # Expires in is lowered a bit in-order to lower the chance of a dead token being used. self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99) - self._token = str(response.access_token) - self._refresh_token = response.refresh_token + self._token = response.access_token + self._refresh_token = str(response.refresh_token) return self._token def invalidate(self, token: typing.Optional[str] = None) -> None: if not token or token == self._token: self._expire_at = 0.0 self._token = None - self._refresh_token = None + self._refresh_token = "" self._auth_code = None From bbd5a1e9c4e135526f880421bc2e7a3ac95883b8 Mon Sep 17 00:00:00 2001 From: theonlydoublee Date: Fri, 24 Mar 2023 18:23:43 -0600 Subject: [PATCH 08/10] Made Changes From Review - Moved `_is_expired` after properties - Fixed mypy issues - Removed spec usage for Mock on mock_token - Moved strategy into a fixture --- hikari/impl/rest.py | 14 ++-- tests/hikari/impl/test_rest.py | 118 +++++++++------------------------ 2 files changed, 41 insertions(+), 91 deletions(-) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 9be7b67eaf..064f10e724 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -264,7 +264,7 @@ def __init__( self._lock = asyncio.Lock() self._scopes = scopes self._token: typing.Optional[str] = None - self._refresh_token: str = "" + self._refresh_token: typing.Optional[str] = None self._auth_code: typing.Optional[str] = auth_code self._redirect_uri = redirect_uri @@ -273,9 +273,6 @@ def client_id(self) -> snowflakes.Snowflake: """ID of the application this token strategy authenticates with.""" return self._client_id - def _is_expired(self) -> bool: - return time.monotonic() >= self._expire_at - @property def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]]: """Sequence of scopes this token strategy authenticates for.""" @@ -285,6 +282,9 @@ def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]] def token_type(self) -> applications.TokenType: return applications.TokenType.BEARER + def _is_expired(self) -> bool: + return time.monotonic() >= self._expire_at + async def acquire(self, client: rest_api.RESTClient) -> str: if not self._auth_code: raise RuntimeError("Token has been invalidated. Unable to get current or new token") @@ -310,7 +310,9 @@ async def acquire(self, client: rest_api.RESTClient) -> str: ) else: response = await client.refresh_access_token( - client=self._client_id, client_secret=self._client_secret, refresh_token=self._refresh_token + client=self._client_id, + client_secret=self._client_secret, + refresh_token=str(self._refresh_token), ) except errors.ClientHTTPResponseError as exc: @@ -329,7 +331,7 @@ def invalidate(self, token: typing.Optional[str] = None) -> None: if not token or token == self._token: self._expire_at = 0.0 self._token = None - self._refresh_token = "" + self._refresh_token = None self._auth_code = None diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index b86cf6079d..00e1d67b37 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -302,11 +302,20 @@ class TestOAuthCredentialsStrategy: @pytest.fixture() def mock_token(self): return mock.Mock( - applications.PartialOAuth2Token, expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="mockmock.tokentoken.mocktoken", - refresh_token=7654, + refresh_token="7654", + ) + + @pytest.fixture() + def strategy(self): + return rest.OAuthCredentialsStrategy( + client=4321, + client_secret="123123123", + auth_code="auth#code", + redirect_uri="https://web.site/auth/discord", + scopes=("identify",), ) def test_client_id_property(self): @@ -320,63 +329,38 @@ def test_client_id_property(self): assert strategy.client_id == 41551 - def test_scopes_property(self): - token = rest.OAuthCredentialsStrategy( - client=987654321, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - scopes=("identify",), - ) - - assert token.scopes == ("identify",) + def test_scopes_property(self, strategy): + assert strategy.scopes == ("identify",) - def test_token_type_property(self): - token = rest.OAuthCredentialsStrategy( - client=987654321, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) - assert token.token_type is applications.TokenType.BEARER + def test_token_type_property(self, strategy): + assert strategy.token_type is applications.TokenType.BEARER @pytest.mark.asyncio() - async def test_acquire_on_new_instance(self, mock_token): + async def test_acquire_on_new_instance(self, mock_token, strategy): mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - result = await rest.OAuthCredentialsStrategy( - client=987654321, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - scopes=("identify",), - ).acquire(mock_rest) + result = await strategy.acquire(mock_rest) assert result == "mockmock.tokentoken.mocktoken" mock_rest.authorize_access_token.assert_awaited_once_with( - client=987654321, + client=4321, client_secret="123123123", code="auth#code", redirect_uri="https://web.site/auth/discord", ) @pytest.mark.asyncio() - async def test_acquire_handles_out_of_date_token(self, mock_token): + async def test_acquire_handles_out_of_date_token(self, mock_token, strategy): mock_old_token = mock.AsyncMock( applications.PartialOAuth2Token, expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="old.mock.token", - refresh_token=7654, + refresh_token="7654", ) mock_rest = mock.AsyncMock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) + token = await strategy.acquire( mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) ) @@ -385,7 +369,7 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): new_token = await strategy.acquire(mock_rest) mock_rest.refresh_access_token.assert_awaited_once_with( - client=123456789, client_secret="123123123", refresh_token=7654 + client=4321, client_secret="123123123", refresh_token="7654" ) assert new_token != token assert new_token == "mockmock.tokentoken.mocktoken" @@ -397,7 +381,7 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc with mock.patch.object(asyncio, "Lock", return_value=lock): strategy = rest.OAuthCredentialsStrategy( - client=123456789, + client=4321, client_secret="123123123", auth_code="auth#code", redirect_uri="https://web.site/auth/discord", @@ -411,7 +395,7 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc results = await tokens_gather mock_rest.authorize_access_token.assert_awaited_once_with( - client=123456789, + client=4321, client_secret="123123123", code="auth#code", redirect_uri="https://web.site/auth/discord", @@ -423,22 +407,16 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc ] @pytest.mark.asyncio() - async def test_acquire_after_invalidation(self, mock_token): + async def test_acquire_after_invalidation(self, mock_token, strategy): mock_old_token = mock.AsyncMock( applications.PartialOAuth2Token, expires_in=datetime.timedelta(weeks=1), token_type=applications.TokenType.BEARER, access_token="okokok.fofdsasdasdofo.ddd", - refresh_token=7654, + refresh_token="7654", ) mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - scopes=("identify",), - ) + token = await strategy.acquire( mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) ) @@ -448,7 +426,7 @@ async def test_acquire_after_invalidation(self, mock_token): await strategy.acquire(mock_rest) @pytest.mark.asyncio() - async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self): + async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self, strategy): class MockLock: def __init__(self, strat): self._strategy = strat @@ -461,12 +439,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): return mock_rest = mock.AsyncMock() - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) strategy._lock = MockLock(strategy) strategy._token = None strategy._expire_at = time.monotonic() + 500 @@ -478,18 +450,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): mock_rest.authorize_access_token.assert_not_called() @pytest.mark.asyncio() - async def test_acquire_caches_client_http_response_error(self): + async def test_acquire_caches_client_http_response_error(self, strategy): mock_rest = mock.AsyncMock() error = errors.ClientHTTPResponseError( url="okokok", status=42, headers={}, raw_body=b"ok", message="OK", code=34123 ) mock_rest.authorize_access_token.side_effect = error - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) with pytest.raises(errors.ClientHTTPResponseError): await strategy.acquire(mock_rest) @@ -498,19 +464,13 @@ async def test_acquire_caches_client_http_response_error(self): await strategy.acquire(mock_rest) mock_rest.authorize_access_token.assert_awaited_once_with( - client=123456789, + client=4321, client_secret="123123123", code="auth#code", redirect_uri="https://web.site/auth/discord", ) - def test_invalidate_when_token_is_not_stored_token(self): - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) + def test_invalidate_when_token_is_not_stored_token(self, strategy): strategy._expire_at = 10.0 strategy._token = "token" @@ -519,13 +479,7 @@ def test_invalidate_when_token_is_not_stored_token(self): assert strategy._expire_at == 10.0 assert strategy._token == "token" - def test_invalidate_when_no_token_specified(self): - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) + def test_invalidate_when_no_token_specified(self, strategy): strategy._expire_at = 10.0 strategy._token = "token" @@ -534,13 +488,7 @@ def test_invalidate_when_no_token_specified(self): assert strategy._expire_at == 0.0 assert strategy._token is None - def test_invalidate_when_token_is_stored_token(self): - strategy = rest.OAuthCredentialsStrategy( - client=123456789, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) + def test_invalidate_when_token_is_stored_token(self, strategy): strategy._expire_at = 10.0 strategy._token = "token" From f91b3c9c41f21df8a4c3595117663d82294ad9bf Mon Sep 17 00:00:00 2001 From: davfsa Date: Sat, 25 Mar 2023 10:21:59 +0100 Subject: [PATCH 09/10] Last touch-ups - Remove unnecessary lock in tests - Remove `scopes` argument (Discord doesn't even document it any more) - Fix typehint for refresh_token --- changes/1558.breaking.md | 1 + hikari/api/rest.py | 19 ++--- hikari/applications.py | 2 +- hikari/impl/rest.py | 44 +++-------- tests/hikari/impl/test_rest.py | 138 +++++++++------------------------ 5 files changed, 57 insertions(+), 147 deletions(-) create mode 100644 changes/1558.breaking.md diff --git a/changes/1558.breaking.md b/changes/1558.breaking.md new file mode 100644 index 0000000000..39df655213 --- /dev/null +++ b/changes/1558.breaking.md @@ -0,0 +1 @@ +Remove incorrect `scopes` parameter from `RESTClient.refresh_access_token`. diff --git a/hikari/api/rest.py b/hikari/api/rest.py index e638d8f649..d4f7540261 100644 --- a/hikari/api/rest.py +++ b/hikari/api/rest.py @@ -2960,6 +2960,12 @@ async def authorize_access_token( ) -> applications.OAuth2AuthorizationToken: """Authorize an OAuth2 token using the authorize code grant type. + .. warning:: + There is no way to ensure what scopes are granted in the token, + so you should check + `hikari.applications.OAuth2AuthorizationToken.scopes` to validate + that the expected scopes were actually authorized here. + Parameters ---------- client : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication] @@ -2995,16 +3001,12 @@ async def refresh_access_token( client: snowflakes.SnowflakeishOr[guilds.PartialApplication], client_secret: str, refresh_token: str, - *, - scopes: undefined.UndefinedOr[ - typing.Sequence[typing.Union[applications.OAuth2Scope, str]] - ] = undefined.UNDEFINED, ) -> applications.OAuth2AuthorizationToken: """Refresh an access token. .. warning:: - As of writing this Discord currently ignores any passed scopes, - therefore you should use + There is no way to ensure what scopes are granted in the token, + so you should check `hikari.applications.OAuth2AuthorizationToken.scopes` to validate that the expected scopes were actually authorized here. @@ -3017,11 +3019,6 @@ async def refresh_access_token( refresh_token : str The refresh token to use. - Other Parameters - ---------------- - scopes : typing.Sequence[typing.Union[hikari.applications.OAuth2Scope, str]] - The scope of the access request. - Returns ------- hikari.applications.OAuth2AuthorizationToken diff --git a/hikari/applications.py b/hikari/applications.py index d7a3f17197..3da5f03b3e 100644 --- a/hikari/applications.py +++ b/hikari/applications.py @@ -745,7 +745,7 @@ def __str__(self) -> str: class OAuth2AuthorizationToken(PartialOAuth2Token): """Model for the OAuth2 token data returned by the authorization grant flow.""" - refresh_token: int = attr.field(eq=False, hash=False, repr=False) + refresh_token: str = attr.field(eq=False, hash=False, repr=False) """Refresh token used to obtain new access tokens with the same grant.""" webhook: typing.Optional[webhooks.IncomingWebhook] = attr.field(eq=False, hash=False, repr=True) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 064f10e724..a3813ff570 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -224,15 +224,10 @@ class OAuthCredentialsStrategy(rest_api.TokenStrategy): authorize as. client_secret : str Client secret to use when authorizing. - auth_code : str - Auth code given from Discord when user authorizes + code : str + The authorization code to exchange for an OAuth2 access token. redirect_uri: str - The redirect uri that was included in the authorization request - - Other Parameters - ---------------- - scopes : typing.Sequence[str] - The scopes to authorize for. + The redirect uri that was included in the authorization request. """ __slots__: typing.Sequence[str] = ( @@ -241,9 +236,8 @@ class OAuthCredentialsStrategy(rest_api.TokenStrategy): "_exception", "_expire_at", "_lock", - "_scopes", "_token", - "_auth_code", + "_code", "_redirect_uri", "_refresh_token", ) @@ -252,20 +246,17 @@ def __init__( self, client: snowflakes.SnowflakeishOr[guilds.PartialApplication], client_secret: str, - auth_code: str, + code: str, redirect_uri: str, - *, - scopes: typing.Sequence[typing.Union[applications.OAuth2Scope, str]] = (applications.OAuth2Scope.IDENTIFY,), ) -> None: self._client_id = snowflakes.Snowflake(client) self._client_secret = client_secret self._exception: typing.Optional[errors.ClientHTTPResponseError] = None self._expire_at = 0.0 self._lock = asyncio.Lock() - self._scopes = scopes self._token: typing.Optional[str] = None self._refresh_token: typing.Optional[str] = None - self._auth_code: typing.Optional[str] = auth_code + self._code: typing.Optional[str] = code self._redirect_uri = redirect_uri @property @@ -273,11 +264,6 @@ def client_id(self) -> snowflakes.Snowflake: """ID of the application this token strategy authenticates with.""" return self._client_id - @property - def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]]: - """Sequence of scopes this token strategy authenticates for.""" - return self._scopes - @property def token_type(self) -> applications.TokenType: return applications.TokenType.BEARER @@ -286,7 +272,7 @@ def _is_expired(self) -> bool: return time.monotonic() >= self._expire_at async def acquire(self, client: rest_api.RESTClient) -> str: - if not self._auth_code: + if not self._code: raise RuntimeError("Token has been invalidated. Unable to get current or new token") if self._token and not self._is_expired(): @@ -305,14 +291,15 @@ async def acquire(self, client: rest_api.RESTClient) -> str: response = await client.authorize_access_token( client=self._client_id, client_secret=self._client_secret, - code=self._auth_code, + code=self._code, redirect_uri=self._redirect_uri, ) else: + assert self._refresh_token is not None response = await client.refresh_access_token( client=self._client_id, client_secret=self._client_secret, - refresh_token=str(self._refresh_token), + refresh_token=self._refresh_token, ) except errors.ClientHTTPResponseError as exc: @@ -324,7 +311,7 @@ async def acquire(self, client: rest_api.RESTClient) -> str: # Expires in is lowered a bit in-order to lower the chance of a dead token being used. self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99) self._token = response.access_token - self._refresh_token = str(response.refresh_token) + self._refresh_token = response.refresh_token return self._token def invalidate(self, token: typing.Optional[str] = None) -> None: @@ -332,7 +319,7 @@ def invalidate(self, token: typing.Optional[str] = None) -> None: self._expire_at = 0.0 self._token = None self._refresh_token = None - self._auth_code = None + self._code = None class _RESTProvider(traits.RESTAware): @@ -2332,19 +2319,12 @@ async def refresh_access_token( client: snowflakes.SnowflakeishOr[guilds.PartialApplication], client_secret: str, refresh_token: str, - *, - scopes: undefined.UndefinedOr[ - typing.Sequence[typing.Union[applications.OAuth2Scope, str]] - ] = undefined.UNDEFINED, ) -> applications.OAuth2AuthorizationToken: route = routes.POST_TOKEN.compile() form_builder = data_binding.URLEncodedFormBuilder() form_builder.add_field("grant_type", "refresh_token") form_builder.add_field("refresh_token", refresh_token) - if scopes is not undefined.UNDEFINED: - form_builder.add_field("scope", " ".join(scopes)) - response = await self._request( route, form_builder=form_builder, auth=self._gen_oauth2_token(client, client_secret) ) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 00e1d67b37..f1864f1f65 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -63,6 +63,14 @@ from hikari.internal import time from tests.hikari import hikari_test_helpers + +class StubModel(snowflakes.Unique): + id = None + + def __init__(self, id=0): + self.id = snowflakes.Snowflake(id) + + ################# # _RESTProvider # ################# @@ -173,18 +181,12 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): @pytest.mark.asyncio() async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token): - lock = asyncio.Lock() - mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(side_effect=[mock_token])) - - with mock.patch.object(asyncio, "Lock", return_value=lock): - strategy = rest.ClientCredentialsStrategy(client=6512312, client_secret="453123123") - - async with lock: - tokens_gather = asyncio.gather( - strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) - ) + mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(return_value=mock_token)) + strategy = rest.ClientCredentialsStrategy(client=6512312, client_secret="453123123") - results = await tokens_gather + results = await asyncio.gather( + strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) + ) mock_rest.authorize_client_credentials_token.assert_awaited_once_with( client=6512312, client_secret="453123123", scopes=("applications.commands.update", "identify") @@ -313,35 +315,28 @@ def strategy(self): return rest.OAuthCredentialsStrategy( client=4321, client_secret="123123123", - auth_code="auth#code", + code="auth#code", redirect_uri="https://web.site/auth/discord", - scopes=("identify",), ) def test_client_id_property(self): - mock_client = hikari_test_helpers.mock_class_namespace(applications.Application, id=41551, init_=False)() strategy = rest.OAuthCredentialsStrategy( - client=mock_client, + client=StubModel(41551), client_secret="123123123", - auth_code="auth#code", + code="auth#code", redirect_uri="https://web.site/auth/discord", ) assert strategy.client_id == 41551 - def test_scopes_property(self, strategy): - assert strategy.scopes == ("identify",) - def test_token_type_property(self, strategy): assert strategy.token_type is applications.TokenType.BEARER @pytest.mark.asyncio() async def test_acquire_on_new_instance(self, mock_token, strategy): - mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - - result = await strategy.acquire(mock_rest) + mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - assert result == "mockmock.tokentoken.mocktoken" + assert await strategy.acquire(mock_rest) == "mockmock.tokentoken.mocktoken" mock_rest.authorize_access_token.assert_awaited_once_with( client=4321, @@ -352,47 +347,26 @@ async def test_acquire_on_new_instance(self, mock_token, strategy): @pytest.mark.asyncio() async def test_acquire_handles_out_of_date_token(self, mock_token, strategy): - mock_old_token = mock.AsyncMock( - applications.PartialOAuth2Token, - expires_in=datetime.timedelta(weeks=1), - token_type=applications.TokenType.BEARER, - access_token="old.mock.token", - refresh_token="7654", - ) - mock_rest = mock.AsyncMock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) + mock_rest = mock.Mock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) + strategy._expire_at = 0 + strategy._token = "old token" + strategy._refresh_token = "refresh token" - token = await strategy.acquire( - mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) - ) - - with mock.patch.object(time, "monotonic", return_value=99999999999): - new_token = await strategy.acquire(mock_rest) + new_token = await strategy.acquire(mock_rest) mock_rest.refresh_access_token.assert_awaited_once_with( - client=4321, client_secret="123123123", refresh_token="7654" + client=4321, client_secret="123123123", refresh_token="refresh token" ) - assert new_token != token - assert new_token == "mockmock.tokentoken.mocktoken" - @pytest.mark.asyncio() - async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token): - lock = asyncio.Lock() - mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(side_effect=[mock_token])) - - with mock.patch.object(asyncio, "Lock", return_value=lock): - strategy = rest.OAuthCredentialsStrategy( - client=4321, - client_secret="123123123", - auth_code="auth#code", - redirect_uri="https://web.site/auth/discord", - ) + assert new_token == strategy._token == "mockmock.tokentoken.mocktoken" - async with lock: - tokens_gather = asyncio.gather( - strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) - ) + @pytest.mark.asyncio() + async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token, strategy): + mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) - results = await tokens_gather + results = await asyncio.gather( + strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) + ) mock_rest.authorize_access_token.assert_awaited_once_with( client=4321, @@ -408,22 +382,10 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc @pytest.mark.asyncio() async def test_acquire_after_invalidation(self, mock_token, strategy): - mock_old_token = mock.AsyncMock( - applications.PartialOAuth2Token, - expires_in=datetime.timedelta(weeks=1), - token_type=applications.TokenType.BEARER, - access_token="okokok.fofdsasdasdofo.ddd", - refresh_token="7654", - ) - mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) + strategy._code = None - token = await strategy.acquire( - mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token)) - ) - - strategy.invalidate(token) - with pytest.raises(RuntimeError): - await strategy.acquire(mock_rest) + with pytest.raises(RuntimeError, match=r"Token has been invalidated. Unable to get current or new token"): + await strategy.acquire(object()) @pytest.mark.asyncio() async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self, strategy): @@ -438,7 +400,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): return - mock_rest = mock.AsyncMock() + mock_rest = mock.Mock() strategy._lock = MockLock(strategy) strategy._token = None strategy._expire_at = time.monotonic() + 500 @@ -675,13 +637,6 @@ def file_resource_patch(file_resource): yield resource -class StubModel(snowflakes.Unique): - id = None - - def __init__(self, id=0): - self.id = snowflakes.Snowflake(id) - - class TestStringifyHttpMessage: def test_when_body_is_None(self, rest_client): headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} @@ -3963,29 +3918,6 @@ async def test_refresh_access_token_without_scopes(self, rest_client): expected_route, form_builder=mock_url_encoded_form, auth="Basic NDU0MTIzOjEyMzEyMw==" ) - async def test_refresh_access_token_with_scopes(self, rest_client): - expected_route = routes.POST_TOKEN.compile() - mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": 42}) - - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): - result = await rest_client.refresh_access_token(54123, "312312", "a.codett", scopes=["1", "3", "scope43"]) - - mock_url_encoded_form.add_field.assert_has_calls( - [ - mock.call("grant_type", "refresh_token"), - mock.call("refresh_token", "a.codett"), - mock.call("scope", "1 3 scope43"), - ] - ) - assert result is rest_client._entity_factory.deserialize_authorization_token.return_value - rest_client._entity_factory.deserialize_authorization_token.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with( - expected_route, form_builder=mock_url_encoded_form, auth="Basic NTQxMjM6MzEyMzEy" - ) - async def test_revoke_access_token(self, rest_client): expected_route = routes.POST_TOKEN_REVOKE.compile() mock_url_encoded_form = mock.Mock() From b71dc812a02c00834283a09ae94e1f8b5f0afb7f Mon Sep 17 00:00:00 2001 From: davfsa Date: Sat, 25 Mar 2023 18:03:56 +0100 Subject: [PATCH 10/10] Make codespell happy --- tests/hikari/impl/test_rest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index f1864f1f65..d9fa59fdda 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -390,8 +390,8 @@ async def test_acquire_after_invalidation(self, mock_token, strategy): @pytest.mark.asyncio() async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self, strategy): class MockLock: - def __init__(self, strat): - self._strategy = strat + def __init__(self, strategy): + self._strategy = strategy async def __aenter__(self): self._strategy._token = "cab.cab.cab"