Skip to content
1 change: 1 addition & 0 deletions changes/1558.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `OAuthCredentialsStrategy` to `hikari.impl.rest` for OAuth2 flow tokens.
119 changes: 119 additions & 0 deletions hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,125 @@ def invalidate(self, token: typing.Optional[str]) -> None:
self._token = None


class OAuthCredentialsStrategy(rest_api.TokenStrategy):
"""Strategy class for handling OAuth2 authorization.

Parameters
----------
client : typing.Optional[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication]]
Object or ID of the application this client credentials strategy should
authorize as.
client_secret : str
Client secret to use when authorizing.
auth_code : str
Auth code given from Discord when user authorizes
redirect_uri: str
The redirect uri that was included in the authorization request

Other Parameters
----------------
scopes : typing.Sequence[str]
The scopes to authorize for.
"""

__slots__: typing.Sequence[str] = (
"_client_id",
"_client_secret",
"_exception",
"_expire_at",
"_lock",
"_scopes",
"_token",
"_auth_code",
"_redirect_uri",
"_refresh_token",
)

def __init__(
self,
client: snowflakes.SnowflakeishOr[guilds.PartialApplication],
client_secret: str,
auth_code: str,
redirect_uri: str,
*,
scopes: typing.Sequence[typing.Union[applications.OAuth2Scope, str]] = (applications.OAuth2Scope.IDENTIFY,),
) -> None:
self._client_id = snowflakes.Snowflake(client)
self._client_secret = client_secret
self._exception: typing.Optional[errors.ClientHTTPResponseError] = None
self._expire_at = 0.0
self._lock = asyncio.Lock()
self._scopes = scopes
self._token: typing.Optional[str] = None
self._refresh_token = None
self._auth_code = auth_code
self._redirect_uri = redirect_uri

@property
def client_id(self) -> snowflakes.Snowflake:
"""ID of the application this token strategy authenticates with."""
return self._client_id

def _is_expired(self) -> bool:
return time.monotonic() >= self._expire_at

@property
def scopes(self) -> typing.Sequence[typing.Union[applications.OAuth2Scope, str]]:
"""Sequence of scopes this token strategy authenticates for."""
return self._scopes

@property
def token_type(self) -> applications.TokenType:
return applications.TokenType.BEARER

async def acquire(self, client: rest_api.RESTClient) -> str:
if not self._auth_code:
raise RuntimeError("Token has been invalidated. Unable to get current or new token")

if self._token and not self._is_expired():
return self._token

async with self._lock:
if self._token and not self._is_expired():
return self._token

if self._exception:
# If we don't copy the exception then python keeps adding onto the stack each time it's raised.
raise copy.copy(self._exception) from None

try:
if not self._token:
response = await client.authorize_access_token(
client=self._client_id,
client_secret=self._client_secret,
code=self._auth_code,
redirect_uri=self._redirect_uri,
)
else:
response = await client.refresh_access_token(
client=self._client_id, client_secret=self._client_secret, refresh_token=self._refresh_token
)

except errors.ClientHTTPResponseError as exc:
if not isinstance(exc, errors.RateLimitTooLongError):
# If we don't copy the exception then python keeps adding onto the stack each time it's raised.
self._exception = copy.copy(exc)
raise

# Expires in is lowered a bit in-order to lower the chance of a dead token being used.
self._expire_at = time.monotonic() + math.floor(response.expires_in.total_seconds() * 0.99)
self._token = str(response.access_token)
self._refresh_token = response.refresh_token
return self._token

def invalidate(self, token: typing.Optional[str] = None) -> None:
if not token or token == self._token:
self._expire_at = 0.0
self._token = None
self._refresh_token = None
self._auth_code = None


class _RESTProvider(traits.RESTAware):
__slots__: typing.Sequence[str] = ("_entity_factory", "_executor", "_rest")

Expand Down
257 changes: 257 additions & 0 deletions tests/hikari/impl/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,263 @@ def test_invalidate_when_token_is_stored_token(self):
assert strategy._token is None


#############################
# OAuthCredentialsStrategy #
#############################


class TestOAuthCredentialsStrategy:
@pytest.fixture()
def mock_token(self):
return mock.Mock(
applications.PartialOAuth2Token,
expires_in=datetime.timedelta(weeks=1),
token_type=applications.TokenType.BEARER,
access_token="mockmock.tokentoken.mocktoken",
refresh_token=7654,
)

def test_client_id_property(self):
mock_client = hikari_test_helpers.mock_class_namespace(applications.Application, id=41551, init_=False)()
strategy = rest.OAuthCredentialsStrategy(
client=mock_client,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)

assert strategy.client_id == 41551

def test_scopes_property(self):
token = rest.OAuthCredentialsStrategy(
client=987654321,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
scopes=("identify",),
)

assert token.scopes == ("identify",)

def test_token_type_property(self):
token = rest.OAuthCredentialsStrategy(
client=987654321,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
assert token.token_type is applications.TokenType.BEARER

@pytest.mark.asyncio()
async def test_acquire_on_new_instance(self, mock_token):
mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_token))

result = await rest.OAuthCredentialsStrategy(
client=987654321,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
scopes=("identify",),
).acquire(mock_rest)

assert result == "mockmock.tokentoken.mocktoken"

mock_rest.authorize_access_token.assert_awaited_once_with(
client=987654321,
client_secret="123123123",
code="auth#code",
redirect_uri="https://web.site/auth/discord",
)

@pytest.mark.asyncio()
async def test_acquire_handles_out_of_date_token(self, mock_token):
mock_old_token = mock.AsyncMock(
applications.PartialOAuth2Token,
expires_in=datetime.timedelta(weeks=1),
token_type=applications.TokenType.BEARER,
access_token="old.mock.token",
refresh_token=7654,
)
mock_rest = mock.AsyncMock(refresh_access_token=mock.AsyncMock(return_value=mock_token))
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
token = await strategy.acquire(
mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token))
)

with mock.patch.object(time, "monotonic", return_value=99999999999):
new_token = await strategy.acquire(mock_rest)

mock_rest.refresh_access_token.assert_awaited_once_with(
client=123456789, client_secret="123123123", refresh_token=7654
)
assert new_token != token
assert new_token == "mockmock.tokentoken.mocktoken"

@pytest.mark.asyncio()
async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token):
lock = asyncio.Lock()
mock_rest = mock.AsyncMock(authorize_access_token=mock.AsyncMock(side_effect=[mock_token]))

with mock.patch.object(asyncio, "Lock", return_value=lock):
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)

async with lock:
tokens_gather = asyncio.gather(
strategy.acquire(mock_rest), strategy.acquire(mock_rest), strategy.acquire(mock_rest)
)

results = await tokens_gather

mock_rest.authorize_access_token.assert_awaited_once_with(
client=123456789,
client_secret="123123123",
code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
assert results == [
"mockmock.tokentoken.mocktoken",
"mockmock.tokentoken.mocktoken",
"mockmock.tokentoken.mocktoken",
]

@pytest.mark.asyncio()
async def test_acquire_after_invalidation(self, mock_token):
mock_old_token = mock.AsyncMock(
applications.PartialOAuth2Token,
expires_in=datetime.timedelta(weeks=1),
token_type=applications.TokenType.BEARER,
access_token="okokok.fofdsasdasdofo.ddd",
refresh_token=7654,
)
mock_rest = mock.Mock(authorize_access_token=mock.AsyncMock(return_value=mock_token))
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
scopes=("identify",),
)
token = await strategy.acquire(
mock.AsyncMock(authorize_access_token=mock.AsyncMock(return_value=mock_old_token))
)

strategy.invalidate(token)
with pytest.raises(RuntimeError):
await strategy.acquire(mock_rest)

@pytest.mark.asyncio()
async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self):
class MockLock:
def __init__(self, strat):
self._strategy = strat

async def __aenter__(self):
self._strategy._token = "cab.cab.cab"
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
return

mock_rest = mock.AsyncMock()
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
strategy._lock = MockLock(strategy)
strategy._token = None
strategy._expire_at = time.monotonic() + 500

result = await strategy.acquire(mock_rest)

assert result == "cab.cab.cab"

mock_rest.authorize_access_token.assert_not_called()

@pytest.mark.asyncio()
async def test_acquire_caches_client_http_response_error(self):
mock_rest = mock.AsyncMock()
error = errors.ClientHTTPResponseError(
url="okokok", status=42, headers={}, raw_body=b"ok", message="OK", code=34123
)
mock_rest.authorize_access_token.side_effect = error
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)

with pytest.raises(errors.ClientHTTPResponseError):
await strategy.acquire(mock_rest)

with pytest.raises(errors.ClientHTTPResponseError):
await strategy.acquire(mock_rest)

mock_rest.authorize_access_token.assert_awaited_once_with(
client=123456789,
client_secret="123123123",
code="auth#code",
redirect_uri="https://web.site/auth/discord",
)

def test_invalidate_when_token_is_not_stored_token(self):
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
strategy._expire_at = 10.0
strategy._token = "token"

strategy.invalidate("tokena")

assert strategy._expire_at == 10.0
assert strategy._token == "token"

def test_invalidate_when_no_token_specified(self):
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
strategy._expire_at = 10.0
strategy._token = "token"

strategy.invalidate(None)

assert strategy._expire_at == 0.0
assert strategy._token is None

def test_invalidate_when_token_is_stored_token(self):
strategy = rest.OAuthCredentialsStrategy(
client=123456789,
client_secret="123123123",
auth_code="auth#code",
redirect_uri="https://web.site/auth/discord",
)
strategy._expire_at = 10.0
strategy._token = "token"

strategy.invalidate("token")

assert strategy._expire_at == 0.0
assert strategy._token is None


###########
# RESTApp #
###########
Expand Down