Skip to content
2 changes: 2 additions & 0 deletions changes/1558.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add `UserCredentialsStrategy` to `hikari\impl\rest.py` similar to `ClientCredentialsStrategy` to allow simple OAuth2
token generation and refreshing
126 changes: 126 additions & 0 deletions hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,132 @@ def invalidate(self, token: typing.Optional[str]) -> None:
self._token = None


class UserCredentialsStrategy(rest_api.TokenStrategy):
"""Strategy class for handling client credential OAuth2 authorization.

Parameters
----------
client_id : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]]
Object or ID of the application this client credentials strategy should
authorize as.
client_secret : str
Client secret to use when authorizing.
auth_code : str
Auth code given from Discord when user authorizes
redirect_uri: str
The redirect uri that was included in the authorization request

Other Parameters
----------------
scopes : typing.Sequence[str]
The scopes to authorize for.
"""

__slots__: typing.Sequence[str] = (
"_client_id",
"_client_secret",
"_exception",
"_expire_at",
"_lock",
"_scopes",
"_token",
"_code",
"_redirect_uri",
"_refresh_token",
)

def __init__(
self,
client_id: snowflakes.SnowflakeishOr[guilds.PartialApplication],
client_secret: str,
auth_code: str,
redirect_uri: str,
*,
scopes: typing.Sequence[typing.Union[applications.OAuth2Scope, str]] = (applications.OAuth2Scope.IDENTIFY,),
) -> None:
self._client_id = snowflakes.Snowflake(client_id)
self._client_secret = client_secret
self._exception: typing.Optional[errors.ClientHTTPResponseError] = None
self._expire_at = 0.0
self._lock = asyncio.Lock()
self._scopes = scopes
self._token: typing.Optional[str] = None
self._refresh_token = None
self._code = auth_code
self._redirect_uri = redirect_uri

@property
def client_id(self) -> snowflakes.Snowflake:
"""ID of the application this token strategy authenticates with."""
return self._client_id

@property
def _is_expired(self) -> bool:
return time.monotonic() >= self._expire_at

@property
def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]]:
"""Sequence of scopes this token strategy authenticates for."""
return self._scopes

@property
def token_type(self) -> applications.TokenType:
return applications.TokenType.BEARER

async def acquire(self, client: rest_api.RESTClient | RESTApp) -> str:
if not self._code:
raise "Token has been invalidated. Unable to get current or new token"

if self._token and not self._is_expired:
return self._token

async with self._lock:
if self._token and not self._is_expired:
return self._token

if self._exception:
# If we don't copy the exception then python keeps adding onto the stack each time it's raised.
raise copy.copy(self._exception) from None

try:
if isinstance(client, RESTApp):
client = client.acquire()
client.start()

if not self._token:
response = await client.authorize_access_token(
client=self._client_id,
client_secret=self._client_secret,
code=self._code,
redirect_uri=self._redirect_uri,
)
else:
response = await client.refresh_access_token(
client=self._client_id, client_secret=self._client_secret, refresh_token=self._refresh_token
)

except errors.ClientHTTPResponseError as exc:
if not isinstance(exc, errors.RateLimitTooLongError):
# If we don't copy the exception then python keeps adding onto the stack each time it's raised.
self._exception = copy.copy(exc)
await client.close()
raise

# Expires in is lowered a bit in-order to lower the chance of a dead token being used.
self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99)
self._token = f"{response.access_token}"
self._refresh_token = response.refresh_token
await client.close()
return self._token

def invalidate(self, token: typing.Optional[str] = None) -> None:
if not token or token == self._token:
self._expire_at = 0.0
self._token = None
self._refresh_token = None
self._code = None


class _RESTProvider(traits.RESTAware):
__slots__: typing.Sequence[str] = ("_entity_factory", "_executor", "_rest")

Expand Down