diff --git a/pyfcm/baseapi.py b/pyfcm/baseapi.py index d2a598b..2843ff6 100644 --- a/pyfcm/baseapi.py +++ b/pyfcm/baseapi.py @@ -1,5 +1,6 @@ # from __future__ import annotations +from functools import cached_property import json import time import threading @@ -9,7 +10,7 @@ from urllib3 import Retry from google.oauth2 import service_account -from google.oauth2.credentials import Credentials +from google.auth.credentials import Credentials import google.auth.transport.requests from pyfcm.errors import ( @@ -41,7 +42,7 @@ def __init__( Attributes: service_account_file (str): path to service account JSON file project_id (str): project ID of Google account - credentials (Credentials): Google oauth2 credentials instance, such as ADC + credentials (Credentials): Google auth credentials instance, such as ADC, service account one proxy_dict (dict): proxy settings dictionary, use proxy (keys: `http`, `https`) env (dict): environment settings dictionary, for example "app_engine" json_encoder (BaseJSONEncoder): JSON encoder @@ -53,9 +54,8 @@ def __init__( ) self._service_account_file = service_account_file - self._fcm_end_point = None self._project_id = project_id - self.credentials = credentials + self._provided_credentials = credentials self.custom_adapter = adapter self.thread_local = threading.local() @@ -76,22 +76,28 @@ def __init__( self.json_encoder = json_encoder - @property + @cached_property + def _credentials(self) -> Credentials: + if self._provided_credentials is not None: + return self._provided_credentials + + credentials = service_account.Credentials.from_service_account_file( + self._service_account_file, + scopes=["https://www.googleapis.com/auth/firebase.messaging"], + ) + # Service account credentials has project_id (others are not) + self._project_id = credentials.project_id or self._project_id + self._service_account_file = None + return credentials + + @cached_property def fcm_end_point(self) -> str: - if self._fcm_end_point is not None: - return self._fcm_end_point - if self.credentials is None: - self._initialize_credentials() - # prefer the project ID scoped to the supplied credentials. - # If, for some reason, the credentials do not specify a project id, - # we'll check for an explicitly supplied one, and raise an error otherwise - project_id = getattr(self.credentials, "project_id", None) or self._project_id - if not project_id: - raise AuthenticationError( - "Please provide a project_id either explicitly or through Google credentials." - ) - self._fcm_end_point = self.FCM_END_POINT_BASE + f"/{project_id}/messages:send" - return self._fcm_end_point + if self._provided_credentials is None: + # read credentails to resolve project_id if needed + _ = self._credentials + if self._project_id is None: + raise RuntimeError("Please provide a project_id either explicitly or through Google credentials.") + return self.FCM_END_POINT_BASE + f"/{self._project_id}/messages:send" @property def requests_session(self): @@ -171,32 +177,18 @@ def _is_access_token_expired(self, response): return False - def _initialize_credentials(self): - """ - Initialize credentials and FCM endpoint if not already initialized. - """ - if self.credentials is None: - self.credentials = service_account.Credentials.from_service_account_file( - self._service_account_file, - scopes=["https://www.googleapis.com/auth/firebase.messaging"], - ) - self._service_account_file = None - - def _get_access_token(self): + def _get_access_token(self) -> str: """ Generates access token from credentials. If token expires then new access token is generated. Returns: str: Access token """ - if self.credentials is None: - self._initialize_credentials() - # get OAuth 2.0 access token try: request = google.auth.transport.requests.Request() - self.credentials.refresh(request) - return self.credentials.token + self._credentials.refresh(request) + return self._credentials.token # pyright: ignore[reportReturnType] except Exception as e: raise InvalidDataError(e) diff --git a/tests/conftest.py b/tests/conftest.py index 9b0379d..668c86e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,24 +2,20 @@ from unittest.mock import AsyncMock import pytest +from google.auth.credentials import Credentials from pyfcm import FCMNotification from pyfcm.baseapi import BaseAPI -from google.auth.credentials import Credentials class DummyCredentials(Credentials): - def refresh(): + def refresh(self, request): pass - @property - def project_id(self): - return "test" - -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def push_service(): - return FCMNotification(credentials=DummyCredentials()) + return FCMNotification(credentials=DummyCredentials(), project_id="test") @pytest.fixture @@ -48,6 +44,6 @@ def mock_aiohttp_session(mocker): return mock_send -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def base_api(): - return BaseAPI(credentials=DummyCredentials()) + return BaseAPI(credentials=DummyCredentials(), project_id="test") diff --git a/tests/test_baseapi.py b/tests/test_baseapi.py index fb4a2dd..c4ddc39 100644 --- a/tests/test_baseapi.py +++ b/tests/test_baseapi.py @@ -1,6 +1,15 @@ import json import time +import pytest + + +def test_empty_project_id(base_api): + base_api._project_id = None + with pytest.raises(RuntimeError) as e: + base_api.fcm_end_point + assert str(e.value) == "Please provide a project_id either explicitly or through Google credentials." + def test_json_dumps(base_api): json_string = base_api.json_dumps([{"test": "Test"}, {"test2": "Test2"}]) diff --git a/tests/test_fcm.py b/tests/test_fcm.py index 1c1284d..d1832fe 100644 --- a/tests/test_fcm.py +++ b/tests/test_fcm.py @@ -12,9 +12,9 @@ def test_push_service_without_credentials(): def test_push_service_directly_passed_credentials(push_service): # We should infer the project ID/endpoint from credentials # without the need to explcitily pass it + push_service._project_id = "abc123" assert push_service.fcm_end_point == ( - "https://fcm.googleapis.com/v1/projects/" - f"{push_service.credentials.project_id}/messages:send" + "https://fcm.googleapis.com/v1/projects/abc123/messages:send" )