Skip to content
1 change: 1 addition & 0 deletions changes/1558.breaking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove incorrect `scopes` parameter from `RESTClient.refresh_access_token`.
1 change: 1 addition & 0 deletions changes/1558.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `OAuthCredentialsStrategy` to `hikari.impl.rest` for OAuth2 flow tokens.
19 changes: 8 additions & 11 deletions hikari/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2960,6 +2960,12 @@ async def authorize_access_token(
) -> applications.OAuth2AuthorizationToken:
"""Authorize an OAuth2 token using the authorize code grant type.

.. warning::
There is no way to ensure what scopes are granted in the token,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's seemingly no reason to put this note here since there isn't a scopes parameter? Also this change seems outside the scope of this PR

so you should check
`hikari.applications.OAuth2AuthorizationToken.scopes` to validate
that the expected scopes were actually authorized here.

Parameters
----------
client : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]
Expand Down Expand Up @@ -2995,16 +3001,12 @@ async def refresh_access_token(
client: snowflakes.SnowflakeishOr[guilds.PartialApplication],
client_secret: str,
refresh_token: str,
*,
scopes: undefined.UndefinedOr[
typing.Sequence[typing.Union[applications.OAuth2Scope, str]]
] = undefined.UNDEFINED,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be removed. This is a part of the Oauth2 spec and Discord does use this field these days.

) -> applications.OAuth2AuthorizationToken:
"""Refresh an access token.

.. warning::
As of writing this Discord currently ignores any passed scopes,
therefore you should use
There is no way to ensure what scopes are granted in the token,
so you should check
`hikari.applications.OAuth2AuthorizationToken.scopes` to validate
that the expected scopes were actually authorized here.

Expand All @@ -3017,11 +3019,6 @@ async def refresh_access_token(
refresh_token : str
The refresh token to use.

Other Parameters
----------------
scopes : typing.Sequence[typing.Union[hikari.applications.OAuth2Scope, str]]
The scope of the access request.

Returns
-------
hikari.applications.OAuth2AuthorizationToken
Expand Down
2 changes: 1 addition & 1 deletion hikari/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def __str__(self) -> str:
class OAuth2AuthorizationToken(PartialOAuth2Token):
"""Model for the OAuth2 token data returned by the authorization grant flow."""

refresh_token: int = attr.field(eq=False, hash=False, repr=False)
refresh_token: str = attr.field(eq=False, hash=False, repr=False)
"""Refresh token used to obtain new access tokens with the same grant."""

webhook: typing.Optional[webhooks.IncomingWebhook] = attr.field(eq=False, hash=False, repr=True)
Expand Down
115 changes: 108 additions & 7 deletions hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,114 @@ def invalidate(self, token: typing.Optional[str]) -> None:
self._token = None


class OAuthCredentialsStrategy(rest_api.TokenStrategy):
"""Strategy class for handling OAuth2 authorization.

Parameters
----------
client : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]]
Object or ID of the application this client credentials strategy should
authorize as.
client_secret : str
Client secret to use when authorizing.
code : str
The authorization code to exchange for an OAuth2 access token.
redirect_uri: str
The redirect uri that was included in the authorization request.
"""

__slots__: typing.Sequence[str] = (
"_client_id",
"_client_secret",
"_exception",
"_expire_at",
"_lock",
"_token",
"_code",
"_redirect_uri",
"_refresh_token",
)

def __init__(
self,
client: snowflakes.SnowflakeishOr[guilds.PartialApplication],
client_secret: str,
code: str,
redirect_uri: str,
) -> None:
self._client_id = snowflakes.Snowflake(client)
self._client_secret = client_secret
self._exception: typing.Optional[errors.ClientHTTPResponseError] = None
self._expire_at = 0.0
self._lock = asyncio.Lock()
self._token: typing.Optional[str] = None
self._refresh_token: typing.Optional[str] = None
self._code: typing.Optional[str] = code
self._redirect_uri = redirect_uri

@property
def client_id(self) -> snowflakes.Snowflake:
"""ID of the application this token strategy authenticates with."""
return self._client_id

@property
def token_type(self) -> applications.TokenType:
return applications.TokenType.BEARER

def _is_expired(self) -> bool:
return time.monotonic() >= self._expire_at

async def acquire(self, client: rest_api.RESTClient) -> str:
if not self._code:
raise RuntimeError("Token has been invalidated. Unable to get current or new token")

if self._token and not self._is_expired():
return self._token

async with self._lock:
if self._token and not self._is_expired():
return self._token

if self._exception:
# If we don't copy the exception then python keeps adding onto the stack each time it's raised.
raise copy.copy(self._exception) from None

try:
if not self._token:
response = await client.authorize_access_token(
client=self._client_id,
client_secret=self._client_secret,
code=self._code,
redirect_uri=self._redirect_uri,
)
else:
assert self._refresh_token is not None
response = await client.refresh_access_token(
client=self._client_id,
client_secret=self._client_secret,
refresh_token=self._refresh_token,
)

except errors.ClientHTTPResponseError as exc:
if not isinstance(exc, errors.RateLimitTooLongError):
# If we don't copy the exception then python keeps adding onto the stack each time it's raised.
self._exception = copy.copy(exc)
raise

# Expires in is lowered a bit in-order to lower the chance of a dead token being used.
self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99)
self._token = response.access_token
self._refresh_token = response.refresh_token
return self._token

def invalidate(self, token: typing.Optional[str] = None) -> None:
if not token or token == self._token:
self._expire_at = 0.0
self._token = None
self._refresh_token = None
self._code = None


class _RESTProvider(traits.RESTAware):
__slots__: typing.Sequence[str] = ("_entity_factory", "_executor", "_rest")

Expand Down Expand Up @@ -2211,19 +2319,12 @@ async def refresh_access_token(
client: snowflakes.SnowflakeishOr[guilds.PartialApplication],
client_secret: str,
refresh_token: str,
*,
scopes: undefined.UndefinedOr[
typing.Sequence[typing.Union[applications.OAuth2Scope, str]]
] = undefined.UNDEFINED,
) -> applications.OAuth2AuthorizationToken:
route = routes.POST_TOKEN.compile()
form_builder = data_binding.URLEncodedFormBuilder()
form_builder.add_field("grant_type", "refresh_token")
form_builder.add_field("refresh_token", refresh_token)

if scopes is not undefined.UNDEFINED:
form_builder.add_field("scope", " ".join(scopes))

response = await self._request(
route, form_builder=form_builder, auth=self._gen_oauth2_token(client, client_secret)
)
Expand Down
Loading