Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 28 additions & 36 deletions pyfcm/baseapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# from __future__ import annotations

from functools import cached_property
import json
import time
import threading
Expand All @@ -9,7 +10,7 @@
from urllib3 import Retry

from google.oauth2 import service_account
from google.oauth2.credentials import Credentials
from google.auth.credentials import Credentials
import google.auth.transport.requests

from pyfcm.errors import (
Expand Down Expand Up @@ -41,7 +42,7 @@ def __init__(
Attributes:
service_account_file (str): path to service account JSON file
project_id (str): project ID of Google account
credentials (Credentials): Google oauth2 credentials instance, such as ADC
credentials (Credentials): Google auth credentials instance, such as ADC, service account one
proxy_dict (dict): proxy settings dictionary, use proxy (keys: `http`, `https`)
env (dict): environment settings dictionary, for example "app_engine"
json_encoder (BaseJSONEncoder): JSON encoder
Expand All @@ -53,9 +54,8 @@ def __init__(
)

self._service_account_file = service_account_file
self._fcm_end_point = None
self._project_id = project_id
self.credentials = credentials
self._provided_credentials = credentials
self.custom_adapter = adapter
self.thread_local = threading.local()

Expand All @@ -76,22 +76,28 @@ def __init__(

self.json_encoder = json_encoder

@property
@cached_property
def _credentials(self) -> Credentials:
if self._provided_credentials is not None:
return self._provided_credentials

credentials = service_account.Credentials.from_service_account_file(
self._service_account_file,
scopes=["https://www.googleapis.com/auth/firebase.messaging"],
)
# Service account credentials has project_id (others are not)
self._project_id = credentials.project_id or self._project_id
self._service_account_file = None
return credentials

@cached_property
def fcm_end_point(self) -> str:
if self._fcm_end_point is not None:
return self._fcm_end_point
if self.credentials is None:
self._initialize_credentials()
# prefer the project ID scoped to the supplied credentials.
# If, for some reason, the credentials do not specify a project id,
# we'll check for an explicitly supplied one, and raise an error otherwise
project_id = getattr(self.credentials, "project_id", None) or self._project_id
if not project_id:
raise AuthenticationError(
"Please provide a project_id either explicitly or through Google credentials."
)
self._fcm_end_point = self.FCM_END_POINT_BASE + f"/{project_id}/messages:send"
return self._fcm_end_point
if self._provided_credentials is None:
# read credentails to resolve project_id if needed
_ = self._credentials
if self._project_id is None:
raise RuntimeError("Please provide a project_id either explicitly or through Google credentials.")
return self.FCM_END_POINT_BASE + f"/{self._project_id}/messages:send"

@property
def requests_session(self):
Expand Down Expand Up @@ -171,32 +177,18 @@ def _is_access_token_expired(self, response):

return False

def _initialize_credentials(self):
"""
Initialize credentials and FCM endpoint if not already initialized.
"""
if self.credentials is None:
self.credentials = service_account.Credentials.from_service_account_file(
self._service_account_file,
scopes=["https://www.googleapis.com/auth/firebase.messaging"],
)
self._service_account_file = None

def _get_access_token(self):
def _get_access_token(self) -> str:
"""
Generates access token from credentials.
If token expires then new access token is generated.
Returns:
str: Access token
"""
if self.credentials is None:
self._initialize_credentials()

# get OAuth 2.0 access token
try:
request = google.auth.transport.requests.Request()
self.credentials.refresh(request)
return self.credentials.token
self._credentials.refresh(request)
return self._credentials.token # pyright: ignore[reportReturnType]
except Exception as e:
raise InvalidDataError(e)

Expand Down
16 changes: 6 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@
from unittest.mock import AsyncMock

import pytest
from google.auth.credentials import Credentials

from pyfcm import FCMNotification
from pyfcm.baseapi import BaseAPI
from google.auth.credentials import Credentials


class DummyCredentials(Credentials):
def refresh():
def refresh(self, request):
pass

@property
def project_id(self):
return "test"


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def push_service():
return FCMNotification(credentials=DummyCredentials())
return FCMNotification(credentials=DummyCredentials(), project_id="test")


@pytest.fixture
Expand Down Expand Up @@ -48,6 +44,6 @@ def mock_aiohttp_session(mocker):
return mock_send


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def base_api():
return BaseAPI(credentials=DummyCredentials())
return BaseAPI(credentials=DummyCredentials(), project_id="test")
9 changes: 9 additions & 0 deletions tests/test_baseapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import json
import time

import pytest


def test_empty_project_id(base_api):
base_api._project_id = None
with pytest.raises(RuntimeError) as e:
base_api.fcm_end_point
assert str(e.value) == "Please provide a project_id either explicitly or through Google credentials."


def test_json_dumps(base_api):
json_string = base_api.json_dumps([{"test": "Test"}, {"test2": "Test2"}])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def test_push_service_without_credentials():
def test_push_service_directly_passed_credentials(push_service):
# We should infer the project ID/endpoint from credentials
# without the need to explcitily pass it
push_service._project_id = "abc123"
assert push_service.fcm_end_point == (
"https://fcm.googleapis.com/v1/projects/"
f"{push_service.credentials.project_id}/messages:send"
"https://fcm.googleapis.com/v1/projects/abc123/messages:send"
)


Expand Down