diff --git a/pyfcm/baseapi.py b/pyfcm/baseapi.py index d2a598b..2092db4 100644 --- a/pyfcm/baseapi.py +++ b/pyfcm/baseapi.py @@ -1,5 +1,6 @@ # from __future__ import annotations +from functools import cached_property import json import time import threading @@ -8,9 +9,7 @@ from requests.adapters import HTTPAdapter from urllib3 import Retry -from google.oauth2 import service_account -from google.oauth2.credentials import Credentials -import google.auth.transport.requests +from google.auth.credentials import Credentials from pyfcm.errors import ( AuthenticationError, @@ -19,6 +18,7 @@ FCMServerError, FCMNotRegisteredError, ) +from pyfcm.token_manager import TokenManager # Migration to v1 - https://firebase.google.com/docs/cloud-messaging/migrate-v1 @@ -41,21 +41,17 @@ def __init__( Attributes: service_account_file (str): path to service account JSON file project_id (str): project ID of Google account - credentials (Credentials): Google oauth2 credentials instance, such as ADC + credentials (Credentials): Google auth credentials instance, such as ADC, service account one proxy_dict (dict): proxy settings dictionary, use proxy (keys: `http`, `https`) env (dict): environment settings dictionary, for example "app_engine" json_encoder (BaseJSONEncoder): JSON encoder adapter (BaseAdapter): adapter instance """ - if not (service_account_file or credentials): - raise AuthenticationError( - "Please provide a service account file path or credentials in the constructor" - ) - - self._service_account_file = service_account_file - self._fcm_end_point = None - self._project_id = project_id - self.credentials = credentials + self.token_manager = TokenManager( + service_account_file=service_account_file, + project_id=project_id, + credentials=credentials, + ) self.custom_adapter = adapter self.thread_local = threading.local() @@ -76,22 +72,11 @@ def __init__( self.json_encoder = json_encoder - @property + @cached_property def fcm_end_point(self) -> str: - if self._fcm_end_point is not None: - return self._fcm_end_point - if self.credentials is None: - self._initialize_credentials() - # prefer the project ID scoped to the supplied credentials. - # If, for some reason, the credentials do not specify a project id, - # we'll check for an explicitly supplied one, and raise an error otherwise - project_id = getattr(self.credentials, "project_id", None) or self._project_id - if not project_id: - raise AuthenticationError( - "Please provide a project_id either explicitly or through Google credentials." - ) - self._fcm_end_point = self.FCM_END_POINT_BASE + f"/{project_id}/messages:send" - return self._fcm_end_point + return ( + self.FCM_END_POINT_BASE + f"/{self.token_manager.project_id}/messages:send" + ) @property def requests_session(self): @@ -105,12 +90,9 @@ def requests_session(self): self.thread_local.requests_session = requests.Session() self.thread_local.requests_session.mount("http://", adapter) self.thread_local.requests_session.mount("https://", adapter) - self.thread_local.token_expiry = 0 - current_timestamp = time.time() - if self.thread_local.token_expiry < current_timestamp: - self.thread_local.requests_session.headers.update(self.request_headers()) - self.thread_local.token_expiry = current_timestamp + 1800 + # Always update headers with current shared token + self.thread_local.requests_session.headers.update(self.request_headers()) return self.thread_local.requests_session def send_request(self, payload=None, timeout=None): @@ -126,7 +108,7 @@ def send_request(self, payload=None, timeout=None): return self.send_request(payload, timeout) if self._is_access_token_expired(response): - self.thread_local.token_expiry = 0 + self.token_manager.refresh_token_if_expired() return self.send_request(payload, timeout) return response @@ -171,35 +153,6 @@ def _is_access_token_expired(self, response): return False - def _initialize_credentials(self): - """ - Initialize credentials and FCM endpoint if not already initialized. - """ - if self.credentials is None: - self.credentials = service_account.Credentials.from_service_account_file( - self._service_account_file, - scopes=["https://www.googleapis.com/auth/firebase.messaging"], - ) - self._service_account_file = None - - def _get_access_token(self): - """ - Generates access token from credentials. - If token expires then new access token is generated. - Returns: - str: Access token - """ - if self.credentials is None: - self._initialize_credentials() - - # get OAuth 2.0 access token - try: - request = google.auth.transport.requests.Request() - self.credentials.refresh(request) - return self.credentials.token - except Exception as e: - raise InvalidDataError(e) - def request_headers(self): """ Generates request headers including Content-Type and Authorization of Bearer token @@ -209,7 +162,7 @@ def request_headers(self): """ return { "Content-Type": "application/json", - "Authorization": "Bearer " + self._get_access_token(), + "Authorization": "Bearer " + self.token_manager.get_access_token(), } def json_dumps(self, data): diff --git a/pyfcm/token_manager.py b/pyfcm/token_manager.py new file mode 100644 index 0000000..f9d0d33 --- /dev/null +++ b/pyfcm/token_manager.py @@ -0,0 +1,151 @@ +from functools import cached_property +import threading +from datetime import datetime, timedelta, timezone +from typing import Optional + +from google.oauth2 import service_account +from google.auth.credentials import Credentials +import google.auth.transport.requests + +from pyfcm.errors import AuthenticationError, InvalidDataError + + +class TokenManager: + """ + Token management class extracted from BaseAPI. + Handles authentication credentials and access token lifecycle. + """ + + def __init__( + self, + service_account_file: Optional[str] = None, + project_id: Optional[str] = None, + credentials: Optional[Credentials] = None, + ): + """ + Initialize TokenManager + + Args: + service_account_file (str): path to service account JSON file + project_id (str): project ID of Google account + credentials (Credentials): Google auth credentials instance + """ + if not (service_account_file or credentials): + raise AuthenticationError( + "Please provide a service account file path or credentials in the constructor" + ) + + self._service_account_file = service_account_file + self._project_id = project_id + self._provided_credentials = credentials + + # Shared token management across threads + self._shared_token = None + self._token_lock = threading.RLock() + + @cached_property + def _credentials(self) -> Credentials: + """ + Get authentication credentials + + Returns: + Credentials: Google authentication credentials + """ + if self._provided_credentials is not None: + return self._provided_credentials + + credentials = service_account.Credentials.from_service_account_file( + self._service_account_file, + scopes=["https://www.googleapis.com/auth/firebase.messaging"], + ) + # Service account credentials has project_id (others are not) + self._project_id = credentials.project_id or self._project_id + self._service_account_file = None + return credentials + + @cached_property + def project_id(self) -> str: + """ + Get project ID + + Returns: + str: Project ID + + Raises: + RuntimeError: If project_id is not configured + """ + # Read credentials to resolve project_id if needed + _ = self._credentials + if self._project_id is None: + raise RuntimeError( + "Please provide a project_id either explicitly or through Google credentials." + ) + return self._project_id + + def _is_token_valid(self) -> bool: + """ + Enhanced token validity check with fallback mechanisms. + Combines expired property check with time-based validation. + + Returns: + bool: True if token is valid, False otherwise + """ + if not self._shared_token: + return False + + if self._credentials.expired: + return False + + # Fallback check: time-based validation with 5-minute buffer + # This accounts for the 4-minute early expiration issue + if ( + hasattr(self._credentials, "expiry") + and self._credentials.expiry + and self._credentials.expiry + <= datetime.now(timezone.utc) + timedelta(minutes=5) + ): + return False + + return True + + def get_access_token(self) -> str: + """ + Thread-safe access token management with shared token across threads. + Uses double-checked locking pattern for performance with enhanced validation. + + Returns: + str: Access token + + Raises: + InvalidDataError: If token acquisition fails + """ + # First check without lock (performance optimization) + if self._is_token_valid(): + return self._shared_token + + # Acquire lock and check again (double-checked locking) + with self._token_lock: + if self._is_token_valid(): + return self._shared_token + + try: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + self._shared_token = self._credentials.token + return self._shared_token + except Exception as e: + raise InvalidDataError(e) + + def refresh_token_if_expired(self) -> None: + """ + Refresh token if needed + """ + with self._token_lock: + self._shared_token = None + if self._credentials: + try: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + except Exception: + # If refresh fails, let the next request handle it + pass diff --git a/tests/conftest.py b/tests/conftest.py index 9b0379d..4da949f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,24 +2,29 @@ from unittest.mock import AsyncMock import pytest +from google.auth.credentials import Credentials from pyfcm import FCMNotification from pyfcm.baseapi import BaseAPI -from google.auth.credentials import Credentials class DummyCredentials(Credentials): - def refresh(): - pass + def __init__(self): + self.token = "dummy_token" + self._expired = True + + def refresh(self, request): + self.token = "refreshed_dummy_token" + self._expired = False @property - def project_id(self): - return "test" + def expired(self): + return self._expired -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def push_service(): - return FCMNotification(credentials=DummyCredentials()) + return FCMNotification(credentials=DummyCredentials(), project_id="test") @pytest.fixture @@ -48,6 +53,6 @@ def mock_aiohttp_session(mocker): return mock_send -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def base_api(): - return BaseAPI(credentials=DummyCredentials()) + return BaseAPI(credentials=DummyCredentials(), project_id="test") diff --git a/tests/test_baseapi.py b/tests/test_baseapi.py index fb4a2dd..1f97c3b 100644 --- a/tests/test_baseapi.py +++ b/tests/test_baseapi.py @@ -1,5 +1,16 @@ import json -import time + +import pytest + + +def test_empty_project_id(base_api): + base_api.token_manager._project_id = None + with pytest.raises(RuntimeError) as e: + base_api.fcm_end_point + assert ( + str(e.value) + == "Please provide a project_id either explicitly or through Google credentials." + ) def test_json_dumps(base_api): @@ -46,7 +57,6 @@ def test_send_request_normal(base_api, mocker): base_api.thread_local = mocker.Mock() base_api.thread_local.requests_session = mock_session - base_api.thread_local.token_expiry = time.time() + 1000 # do result = base_api.send_request(payload="test_payload", timeout=30) @@ -73,7 +83,6 @@ def test_send_request_retry_after(base_api, mocker): base_api.thread_local = mocker.Mock() base_api.thread_local.requests_session = mock_session - base_api.thread_local.token_expiry = time.time() + 1000 # do result = base_api.send_request(payload="test_payload", timeout=30) @@ -118,11 +127,14 @@ def test_send_request_access_token_expired_retry(base_api, mocker): type(base_api), "requests_session", new_callable=mocker.PropertyMock ) mock_requests_session.return_value = mock_session + base_api.token_manager._shared_token = "dummy" + assert base_api.token_manager._shared_token is not None # do result = base_api.send_request(payload="test_payload", timeout=30) # check assert mock_session.post.call_count == 2 - assert base_api.thread_local.token_expiry == 0 + # token cleared, but not refreshed because request_session is mocked + assert base_api.token_manager._shared_token is None assert result == success_response diff --git a/tests/test_fcm.py b/tests/test_fcm.py index 1c1284d..4eb492a 100644 --- a/tests/test_fcm.py +++ b/tests/test_fcm.py @@ -12,9 +12,9 @@ def test_push_service_without_credentials(): def test_push_service_directly_passed_credentials(push_service): # We should infer the project ID/endpoint from credentials # without the need to explcitily pass it + push_service.token_manager._project_id = "abc123" assert push_service.fcm_end_point == ( - "https://fcm.googleapis.com/v1/projects/" - f"{push_service.credentials.project_id}/messages:send" + "https://fcm.googleapis.com/v1/projects/abc123/messages:send" )