diff --git a/okta/client.py b/okta/client.py index 065b32ebf..b0fa318b2 100644 --- a/okta/client.py +++ b/okta/client.py @@ -116,7 +116,7 @@ def __init__(self, user_config: dict = {}): client_config_setter = ConfigSetter() client_config_setter._apply_config({'client': user_config}) self._config = client_config_setter.get_config() - # Prune configuration to remove unnecesary fields + # Prune configuration to remove unnecessary fields self._config = client_config_setter._prune_config(self._config) # Validate configuration ConfigValidator(self._config) @@ -128,6 +128,7 @@ def __init__(self, user_config: dict = {}): self._client_id = None self._scopes = None self._private_key = None + self._oauth_token_renewal_offset = None # Determine which cache to use cache = NoOpCache() @@ -154,6 +155,7 @@ def __init__(self, user_config: dict = {}): self._client_id = self._config["client"]["clientId"] self._scopes = self._config["client"]["scopes"] self._private_key = self._config["client"]["privateKey"] + self._oauth_token_renewal_offset = self._config["client"]["oauthTokenRenewalOffset"] setup_logging(log_level=self._config["client"]["logging"]["logLevel"]) # Check if logging should be enabled diff --git a/okta/config/config_setter.py b/okta/config/config_setter.py index b7d8c5dbf..b2db40db0 100644 --- a/okta/config/config_setter.py +++ b/okta/config/config_setter.py @@ -37,7 +37,8 @@ class ConfigSetter(): }, "rateLimit": { "maxRetries": '' - } + }, + "oauthTokenRenewalOffset": '' }, "testing": { "testingDisableHttpsCheck": '' @@ -116,6 +117,7 @@ def _apply_default_values(self): self._config["client"]["rateLimit"] = { "maxRetries": 2 } + self._config["client"]["oauthTokenRenewalOffset"] = 5 self._config["testing"]["testingDisableHttpsCheck"] = False diff --git a/okta/config/config_validator.py b/okta/config/config_validator.py index 9d72900c9..2a0e000f5 100644 --- a/okta/config/config_validator.py +++ b/okta/config/config_validator.py @@ -45,9 +45,8 @@ def validate_config(self): self._validate_token( client.get('token', "")) elif client.get('authorizationMode') == "PrivateKey": - client_fields = ['clientId', 'scopes', 'privateKey'] - client_fields_values = [self._config.get( - 'client').get(field, "") for field in client_fields] + client_fields = ['clientId', 'scopes', 'privateKey', 'oauthTokenRenewalOffset'] + client_fields_values = [client.get(field, "") for field in client_fields] errors += self._validate_client_fields(*client_fields_values) else: # Not a valid authorization mode errors += [ @@ -61,7 +60,7 @@ def validate_config(self): f"See {REPO_URL} for usage") def _validate_client_fields(self, client_id, client_scopes, - client_private_key): + client_private_key, oauth_token_renewal_offset): client_fields_errors = [] # check client id @@ -77,6 +76,14 @@ def _validate_client_fields(self, client_id, client_scopes, if not (client_scopes and client_private_key): client_fields_errors.append(ERROR_MESSAGE_SCOPES_PK_MISSING) + # Validate oauthTokenRenewalOffset + if not oauth_token_renewal_offset: + client_fields_errors.append("oauthTokenRenewalOffset must be provided") + if not isinstance(oauth_token_renewal_offset, int): + client_fields_errors.append("oauthTokenRenewalOffset must be a valid integer") + if isinstance(oauth_token_renewal_offset, int) and oauth_token_renewal_offset < 0: + client_fields_errors.append("oauthTokenRenewalOffset must be a non-negative integer") + return client_fields_errors def _validate_token(self, token: str): diff --git a/okta/oauth.py b/okta/oauth.py index 002c11c71..cfe7716ac 100644 --- a/okta/oauth.py +++ b/okta/oauth.py @@ -1,3 +1,4 @@ +import time from urllib.parse import urlencode, quote from okta.jwt import JWT from okta.http_client import HTTPClient @@ -37,6 +38,14 @@ async def get_access_token(self): str, Exception: Tuple of the access token, error that was raised (if any) """ + + # Check if access token has expired or will expire soon + current_time = int(time.time()) + if self._access_token and hasattr(self, '_access_token_expiry_time'): + renewal_offset = self._config["client"]["oauthTokenRenewalOffset"] * 60 # Convert minutes to seconds + if current_time + renewal_offset >= self._access_token_expiry_time: + self.clear_access_token() + # Return token if already generated if self._access_token: return (self._access_token, None) @@ -82,6 +91,9 @@ async def get_access_token(self): # Otherwise set token and return it self._access_token = parsed_response["access_token"] + + # Set token expiry time + self._access_token_expiry_time = int(time.time()) + parsed_response["expires_in"] return (self._access_token, None) def clear_access_token(self): @@ -91,3 +103,4 @@ def clear_access_token(self): self._access_token = None self._request_executor._cache.delete("OKTA_ACCESS_TOKEN") self._request_executor._default_headers.pop("Authorization", None) + self._access_token_expiry_time = None diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2b87f13ee..a28e0e398 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -914,3 +914,17 @@ async def test_client_session(mocker): config = {'orgUrl': org_url, 'token': token} async with OktaClient(config) as client: assert isinstance(client._request_executor._http_client._session, aiohttp.ClientSession) + + +def test_client_initialization(): + config = { + "orgUrl": "https://dev-1dq2.okta.com/oauth2/default", + "authorizationMode": "PrivateKey", + "clientId": "valid-client-id", + "scopes": ["scope1", "scope2"], + "privateKey": "valid-private-key", + "token": "valid-token", + } + client = OktaClient(config) + assert client._config["client"]["orgUrl"] == "https://dev-1dq2.okta.com/oauth2/default" + assert client._config["client"]["clientId"] == "valid-client-id" \ No newline at end of file diff --git a/tests/unit/test_config_setter.py b/tests/unit/test_config_setter.py new file mode 100644 index 000000000..91b7c92f3 --- /dev/null +++ b/tests/unit/test_config_setter.py @@ -0,0 +1,11 @@ +from okta.config.config_setter import ConfigSetter + +""" +Testing Config Setter +""" + +def test_env_variable_application(monkeypatch): + config_setter = ConfigSetter() + config_setter._apply_default_values() + + assert config_setter._config["client"]["oauthTokenRenewalOffset"] == 5 diff --git a/tests/unit/test_config_validator.py b/tests/unit/test_config_validator.py new file mode 100644 index 000000000..ad4576085 --- /dev/null +++ b/tests/unit/test_config_validator.py @@ -0,0 +1,41 @@ +import pytest +from okta.config.config_validator import ConfigValidator + +""" +Testing Config Validator +""" + +def test_validate_config_valid(): + config = { + "client": { + "orgUrl": "https://example.okta.com", + "authorizationMode": "PrivateKey", + "clientId": "valid-client-id", + "scopes": ["scope1", "scope2"], + "privateKey": "valid-private-key", + "oauthTokenRenewalOffset": 5 + }, + "testing": { + "testingDisableHttpsCheck": False + } + } + validator = ConfigValidator(config) + assert validator.validate_config() is None + +def test_validate_config_invalid_org_url(): + config = { + "client": { + "orgUrl": "http://example.okta.com", + "authorizationMode": "PrivateKey", + "clientId": "valid-client-id", + "scopes": ["scope1", "scope2"], + "privateKey": "valid-private-key", + "oauthTokenRenewalOffset": 5 + }, + "testing": { + "testingDisableHttpsCheck": False + } + } + with pytest.raises(ValueError) as excinfo: + ConfigValidator(config) + assert "must start with 'https'." in str(excinfo.value) \ No newline at end of file diff --git a/tests/unit/test_oauth.py b/tests/unit/test_oauth.py index 2e608c0da..933165fb7 100644 --- a/tests/unit/test_oauth.py +++ b/tests/unit/test_oauth.py @@ -2,6 +2,8 @@ import tests.mocks as mocks import os import pytest +from unittest.mock import AsyncMock, MagicMock +from okta.oauth import OAuth """ Testing Private Key Inputs @@ -39,3 +41,54 @@ def test_private_key_PEM_JWK_explicit_string(): def test_invalid_private_key_PEM_JWK(private_key): with pytest.raises(ValueError): generated_pem, generated_jwk = JWT.get_PEM_JWK(private_key) + + +@pytest.mark.asyncio +async def test_get_access_token(): + mock_request_executor = MagicMock() + mock_request_executor.create_request = AsyncMock(return_value=({"mock_request": "data"}, None)) + mock_response_details = MagicMock() + mock_response_details.content_type = "application/json" + mock_response_details.status = 200 + mock_request_executor.fire_request = AsyncMock( + return_value=(None, mock_response_details, '{"access_token": "mock_token", "expires_in": 3600}', None)) + + config = { + "client": { + "orgUrl": "https://example.okta.com", + "clientId": "valid-client-id", + "privateKey": mocks.SAMPLE_RSA, + "scopes": ["scope1", "scope2"], + "oauthTokenRenewalOffset": 5 + } + } + oauth = OAuth(mock_request_executor, config) + token, error = await oauth.get_access_token() + + assert token == "mock_token" + assert error is None + + +@pytest.mark.asyncio +async def test_clear_access_token(): + mock_request_executor = MagicMock() + mock_request_executor._cache = MagicMock() + mock_request_executor._default_headers = {} + + config = { + "client": { + "orgUrl": "https://example.okta.com", + "clientId": "valid-client-id", + "privateKey": "valid-private-key", + "scopes": ["scope1", "scope2"], + "oauthTokenRenewalOffset": 5 + } + } + oauth = OAuth(mock_request_executor, config) + oauth._access_token = "mock_token" + oauth._access_token_expiry_time = 1234567890 + + oauth.clear_access_token() + + assert oauth._access_token is None + assert oauth._access_token_expiry_time is None \ No newline at end of file