diff --git a/changes/1558.breaking.md b/changes/1558.breaking.md new file mode 100644 index 0000000000..39df655213 --- /dev/null +++ b/changes/1558.breaking.md @@ -0,0 +1 @@ +Remove incorrect `scopes` parameter from `RESTClient.refresh_access_token`. diff --git a/changes/1558.feature.md b/changes/1558.feature.md new file mode 100644 index 0000000000..d80eaf37c3 --- /dev/null +++ b/changes/1558.feature.md @@ -0,0 +1 @@ +Add `OAuthCredentialsStrategy` to `hikari.impl.rest` for OAuth2 flow tokens. diff --git a/hikari/api/rest.py b/hikari/api/rest.py index e638d8f649..d4f7540261 100644 --- a/hikari/api/rest.py +++ b/hikari/api/rest.py @@ -2960,6 +2960,12 @@ async def authorize_access_token( ) -> applications.OAuth2AuthorizationToken: """Authorize an OAuth2 token using the authorize code grant type. + .. warning:: + There is no way to ensure what scopes are granted in the token, + so you should check + `hikari.applications.OAuth2AuthorizationToken.scopes` to validate + that the expected scopes were actually authorized here. + Parameters ---------- client : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication] @@ -2995,16 +3001,12 @@ async def refresh_access_token( client: snowflakes.SnowflakeishOr[guilds.PartialApplication], client_secret: str, refresh_token: str, - *, - scopes: undefined.UndefinedOr[ - typing.Sequence[typing.Union[applications.OAuth2Scope, str]] - ] = undefined.UNDEFINED, ) -> applications.OAuth2AuthorizationToken: """Refresh an access token. .. warning:: - As of writing this Discord currently ignores any passed scopes, - therefore you should use + There is no way to ensure what scopes are granted in the token, + so you should check `hikari.applications.OAuth2AuthorizationToken.scopes` to validate that the expected scopes were actually authorized here. @@ -3017,11 +3019,6 @@ async def refresh_access_token( refresh_token : str The refresh token to use. - Other Parameters - ---------------- - scopes : typing.Sequence[typing.Union[hikari.applications.OAuth2Scope, str]] - The scope of the access request. - Returns ------- hikari.applications.OAuth2AuthorizationToken diff --git a/hikari/applications.py b/hikari/applications.py index d7a3f17197..3da5f03b3e 100644 --- a/hikari/applications.py +++ b/hikari/applications.py @@ -745,7 +745,7 @@ def __str__(self) -> str: class OAuth2AuthorizationToken(PartialOAuth2Token): """Model for the OAuth2 token data returned by the authorization grant flow.""" - refresh_token: int = attr.field(eq=False, hash=False, repr=False) + refresh_token: str = attr.field(eq=False, hash=False, repr=False) """Refresh token used to obtain new access tokens with the same grant.""" webhook: typing.Optional[webhooks.IncomingWebhook] = attr.field(eq=False, hash=False, repr=True) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 5e1b42d31e..a3813ff570 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -214,6 +214,114 @@ def invalidate(self, token: typing.Optional[str]) -> None: self._token = None +class OAuthCredentialsStrategy(rest_api.TokenStrategy): + """Strategy class for handling OAuth2 authorization. + + Parameters + ---------- + client : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]] + Object or ID of the application this client credentials strategy should + authorize as. + client_secret : str + Client secret to use when authorizing. + code : str + The authorization code to exchange for an OAuth2 access token. + redirect_uri: str + The redirect uri that was included in the authorization request. + """ + + __slots__: typing.Sequence[str] = ( + "_client_id", + "_client_secret", + "_exception", + "_expire_at", + "_lock", + "_token", + "_code", + "_redirect_uri", + "_refresh_token", + ) + + def __init__( + self, + client: snowflakes.SnowflakeishOr[guilds.PartialApplication], + client_secret: str, + code: str, + redirect_uri: str, + ) -> None: + self._client_id = snowflakes.Snowflake(client) + self._client_secret = client_secret + self._exception: typing.Optional[errors.ClientHTTPResponseError] = None + self._expire_at = 0.0 + self._lock = asyncio.Lock() + self._token: typing.Optional[str] = None + self._refresh_token: typing.Optional[str] = None + self._code: typing.Optional[str] = code + self._redirect_uri = redirect_uri + + @property + def client_id(self) -> snowflakes.Snowflake: + """ID of the application this token strategy authenticates with.""" + return self._client_id + + @property + def token_type(self) -> applications.TokenType: + return applications.TokenType.BEARER + + def _is_expired(self) -> bool: + return time.monotonic() >= self._expire_at + + async def acquire(self, client: rest_api.RESTClient) -> str: + if not self._code: + raise RuntimeError("Token has been invalidated. Unable to get current or new token") + + if self._token and not self._is_expired(): + return self._token + + async with self._lock: + if self._token and not self._is_expired(): + return self._token + + if self._exception: + # If we don't copy the exception then python keeps adding onto the stack each time it's raised. + raise copy.copy(self._exception) from None + + try: + if not self._token: + response = await client.authorize_access_token( + client=self._client_id, + client_secret=self._client_secret, + code=self._code, + redirect_uri=self._redirect_uri, + ) + else: + assert self._refresh_token is not None + response = await client.refresh_access_token( + client=self._client_id, + client_secret=self._client_secret, + refresh_token=self._refresh_token, + ) + + except errors.ClientHTTPResponseError as exc: + if not isinstance(exc, errors.RateLimitTooLongError): + # If we don't copy the exception then python keeps adding onto the stack each time it's raised. + self._exception = copy.copy(exc) + raise + + # Expires in is lowered a bit in-order to lower the chance of a dead token being used. + self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99) + self._token = response.access_token + self._refresh_token = response.refresh_token + return self._token + + def invalidate(self, token: typing.Optional[str] = None) -> None: + if not token or token == self._token: + self._expire_at = 0.0 + self._token = None + self._refresh_token = None + self._code = None + + class _RESTProvider(traits.RESTAware): __slots__: typing.Sequence[str] = ("_entity_factory", "_executor", "_rest") @@ -2211,19 +2319,12 @@ async def refresh_access_token( client: snowflakes.SnowflakeishOr[guilds.PartialApplication], client_secret: str, refresh_token: str, - *, - scopes: undefined.UndefinedOr[ - typing.Sequence[typing.Union[applications.OAuth2Scope, str]] - ] = undefined.UNDEFINED, ) -> applications.OAuth2AuthorizationToken: route = routes.POST_TOKEN.compile() form_builder = data_binding.URLEncodedFormBuilder() form_builder.add_field("grant_type", "refresh_token") form_builder.add_field("refresh_token", refresh_token) - if scopes is not undefined.UNDEFINED: - form_builder.add_field("scope", " ".join(scopes)) - response = await self._request( route, form_builder=form_builder, auth=self._gen_oauth2_token(client, client_secret) ) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 1a062e5d67..d9fa59fdda 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -63,6 +63,14 @@ from hikari.internal import time from tests.hikari import hikari_test_helpers + +class StubModel(snowflakes.Unique): + id = None + + def __init__(self, id=0): + self.id = snowflakes.Snowflake(id) + + ################# # _RESTProvider # ################# @@ -173,18 +181,12 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): @pytest.mark.asyncio() async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token): - lock = asyncio.Lock() - mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(side_effect=[mock_token])) - - with mock.patch.object(asyncio, "Lock", return_value=lock): - strategy = rest.ClientCredentialsStrategy(client=6512312, client_secret="453123123") - - async with lock: - tokens_gather = asyncio.gather( - strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) - ) + mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(return_value=mock_token)) + strategy = rest.ClientCredentialsStrategy(client=6512312, client_secret="453123123") - results = await tokens_gather + results = await asyncio.gather( + strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) + ) mock_rest.authorize_client_credentials_token.assert_awaited_once_with( client=6512312, client_secret="453123123", scopes=("applications.commands.update", "identify") @@ -293,6 +295,171 @@ def test_invalidate_when_token_is_stored_token(self): assert strategy._token is None +############################# +# OAuthCredentialsStrategy # +############################# + + +class TestOAuthCredentialsStrategy: + @pytest.fixture() + def mock_token(self): + return mock.Mock( + expires_in=datetime.timedelta(weeks=1), + token_type=applications.TokenType.BEARER, + access_token="mockmock.tokentoken.mocktoken", + refresh_token="7654", + ) + + @pytest.fixture() + def strategy(self): + return rest.OAuthCredentialsStrategy( + client=4321, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + + def test_client_id_property(self): + strategy = rest.OAuthCredentialsStrategy( + client=StubModel(41551), + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + + assert strategy.client_id == 41551 + + def test_token_type_property(self, strategy): + assert strategy.token_type is applications.TokenType.BEARER + + @pytest.mark.asyncio() + async def test_acquire_on_new_instance(self, mock_token, strategy): + mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) + + assert await strategy.acquire(mock_rest) == "mockmock.tokentoken.mocktoken" + + mock_rest.authorize_access_token.assert_awaited_once_with( + client=4321, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + + @pytest.mark.asyncio() + async def test_acquire_handles_out_of_date_token(self, mock_token, strategy): + mock_rest = mock.Mock(refresh_access_token=mock.AsyncMock(return_value=mock_token)) + strategy._expire_at = 0 + strategy._token = "old token" + strategy._refresh_token = "refresh token" + + new_token = await strategy.acquire(mock_rest) + + mock_rest.refresh_access_token.assert_awaited_once_with( + client=4321, client_secret="123123123", refresh_token="refresh token" + ) + + assert new_token == strategy._token == "mockmock.tokentoken.mocktoken" + + @pytest.mark.asyncio() + async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token, strategy): + mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token)) + + results = await asyncio.gather( + strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest) + ) + + mock_rest.authorize_access_token.assert_awaited_once_with( + client=4321, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + assert results == [ + "mockmock.tokentoken.mocktoken", + "mockmock.tokentoken.mocktoken", + "mockmock.tokentoken.mocktoken", + ] + + @pytest.mark.asyncio() + async def test_acquire_after_invalidation(self, mock_token, strategy): + strategy._code = None + + with pytest.raises(RuntimeError, match=r"Token has been invalidated. Unable to get current or new token"): + await strategy.acquire(object()) + + @pytest.mark.asyncio() + async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self, strategy): + class MockLock: + def __init__(self, strategy): + self._strategy = strategy + + async def __aenter__(self): + self._strategy._token = "cab.cab.cab" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return + + mock_rest = mock.Mock() + strategy._lock = MockLock(strategy) + strategy._token = None + strategy._expire_at = time.monotonic() + 500 + + result = await strategy.acquire(mock_rest) + + assert result == "cab.cab.cab" + + mock_rest.authorize_access_token.assert_not_called() + + @pytest.mark.asyncio() + async def test_acquire_caches_client_http_response_error(self, strategy): + mock_rest = mock.AsyncMock() + error = errors.ClientHTTPResponseError( + url="okokok", status=42, headers={}, raw_body=b"ok", message="OK", code=34123 + ) + mock_rest.authorize_access_token.side_effect = error + + with pytest.raises(errors.ClientHTTPResponseError): + await strategy.acquire(mock_rest) + + with pytest.raises(errors.ClientHTTPResponseError): + await strategy.acquire(mock_rest) + + mock_rest.authorize_access_token.assert_awaited_once_with( + client=4321, + client_secret="123123123", + code="auth#code", + redirect_uri="https://web.site/auth/discord", + ) + + def test_invalidate_when_token_is_not_stored_token(self, strategy): + strategy._expire_at = 10.0 + strategy._token = "token" + + strategy.invalidate("tokena") + + assert strategy._expire_at == 10.0 + assert strategy._token == "token" + + def test_invalidate_when_no_token_specified(self, strategy): + strategy._expire_at = 10.0 + strategy._token = "token" + + strategy.invalidate(None) + + assert strategy._expire_at == 0.0 + assert strategy._token is None + + def test_invalidate_when_token_is_stored_token(self, strategy): + strategy._expire_at = 10.0 + strategy._token = "token" + + strategy.invalidate("token") + + assert strategy._expire_at == 0.0 + assert strategy._token is None + + ########### # RESTApp # ########### @@ -470,13 +637,6 @@ def file_resource_patch(file_resource): yield resource -class StubModel(snowflakes.Unique): - id = None - - def __init__(self, id=0): - self.id = snowflakes.Snowflake(id) - - class TestStringifyHttpMessage: def test_when_body_is_None(self, rest_client): headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} @@ -3758,29 +3918,6 @@ async def test_refresh_access_token_without_scopes(self, rest_client): expected_route, form_builder=mock_url_encoded_form, auth="Basic NDU0MTIzOjEyMzEyMw==" ) - async def test_refresh_access_token_with_scopes(self, rest_client): - expected_route = routes.POST_TOKEN.compile() - mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": 42}) - - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): - result = await rest_client.refresh_access_token(54123, "312312", "a.codett", scopes=["1", "3", "scope43"]) - - mock_url_encoded_form.add_field.assert_has_calls( - [ - mock.call("grant_type", "refresh_token"), - mock.call("refresh_token", "a.codett"), - mock.call("scope", "1 3 scope43"), - ] - ) - assert result is rest_client._entity_factory.deserialize_authorization_token.return_value - rest_client._entity_factory.deserialize_authorization_token.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with( - expected_route, form_builder=mock_url_encoded_form, auth="Basic NTQxMjM6MzEyMzEy" - ) - async def test_revoke_access_token(self, rest_client): expected_route = routes.POST_TOKEN_REVOKE.compile() mock_url_encoded_form = mock.Mock()