Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 17 additions & 64 deletions pyfcm/baseapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# from __future__ import annotations

from functools import cached_property
import json
import time
import threading
Expand All @@ -8,9 +9,7 @@
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from google.oauth2 import service_account
from google.oauth2.credentials import Credentials
import google.auth.transport.requests
from google.auth.credentials import Credentials

from pyfcm.errors import (
AuthenticationError,
Expand All @@ -19,6 +18,7 @@
FCMServerError,
FCMNotRegisteredError,
)
from pyfcm.token_manager import TokenManager

# Migration to v1 - https://firebase.google.com/docs/cloud-messaging/migrate-v1

Expand All @@ -41,21 +41,17 @@ def __init__(
Attributes:
service_account_file (str): path to service account JSON file
project_id (str): project ID of Google account
credentials (Credentials): Google oauth2 credentials instance, such as ADC
credentials (Credentials): Google auth credentials instance, such as ADC, service account one
proxy_dict (dict): proxy settings dictionary, use proxy (keys: `http`, `https`)
env (dict): environment settings dictionary, for example "app_engine"
json_encoder (BaseJSONEncoder): JSON encoder
adapter (BaseAdapter): adapter instance
"""
if not (service_account_file or credentials):
raise AuthenticationError(
"Please provide a service account file path or credentials in the constructor"
)

self._service_account_file = service_account_file
self._fcm_end_point = None
self._project_id = project_id
self.credentials = credentials
self.token_manager = TokenManager(
service_account_file=service_account_file,
project_id=project_id,
credentials=credentials,
)
self.custom_adapter = adapter
self.thread_local = threading.local()

Expand All @@ -76,22 +72,11 @@ def __init__(

self.json_encoder = json_encoder

@property
@cached_property
def fcm_end_point(self) -> str:
if self._fcm_end_point is not None:
return self._fcm_end_point
if self.credentials is None:
self._initialize_credentials()
# prefer the project ID scoped to the supplied credentials.
# If, for some reason, the credentials do not specify a project id,
# we'll check for an explicitly supplied one, and raise an error otherwise
project_id = getattr(self.credentials, "project_id", None) or self._project_id
if not project_id:
raise AuthenticationError(
"Please provide a project_id either explicitly or through Google credentials."
)
self._fcm_end_point = self.FCM_END_POINT_BASE + f"/{project_id}/messages:send"
return self._fcm_end_point
return (
self.FCM_END_POINT_BASE + f"/{self.token_manager.project_id}/messages:send"
)

@property
def requests_session(self):
Expand All @@ -105,12 +90,9 @@ def requests_session(self):
self.thread_local.requests_session = requests.Session()
self.thread_local.requests_session.mount("http://", adapter)
self.thread_local.requests_session.mount("https://", adapter)
self.thread_local.token_expiry = 0

current_timestamp = time.time()
if self.thread_local.token_expiry < current_timestamp:
self.thread_local.requests_session.headers.update(self.request_headers())
self.thread_local.token_expiry = current_timestamp + 1800
# Always update headers with current shared token
self.thread_local.requests_session.headers.update(self.request_headers())
return self.thread_local.requests_session

def send_request(self, payload=None, timeout=None):
Expand All @@ -126,7 +108,7 @@ def send_request(self, payload=None, timeout=None):
return self.send_request(payload, timeout)

if self._is_access_token_expired(response):
self.thread_local.token_expiry = 0
self.token_manager.refresh_token_if_expired()
return self.send_request(payload, timeout)

return response
Expand Down Expand Up @@ -171,35 +153,6 @@ def _is_access_token_expired(self, response):

return False

def _initialize_credentials(self):
"""
Initialize credentials and FCM endpoint if not already initialized.
"""
if self.credentials is None:
self.credentials = service_account.Credentials.from_service_account_file(
self._service_account_file,
scopes=["https://www.googleapis.com/auth/firebase.messaging"],
)
self._service_account_file = None

def _get_access_token(self):
"""
Generates access token from credentials.
If token expires then new access token is generated.
Returns:
str: Access token
"""
if self.credentials is None:
self._initialize_credentials()

# get OAuth 2.0 access token
try:
request = google.auth.transport.requests.Request()
self.credentials.refresh(request)
return self.credentials.token
except Exception as e:
raise InvalidDataError(e)

def request_headers(self):
"""
Generates request headers including Content-Type and Authorization of Bearer token
Expand All @@ -209,7 +162,7 @@ def request_headers(self):
"""
return {
"Content-Type": "application/json",
"Authorization": "Bearer " + self._get_access_token(),
"Authorization": "Bearer " + self.token_manager.get_access_token(),
}

def json_dumps(self, data):
Expand Down
151 changes: 151 additions & 0 deletions pyfcm/token_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from functools import cached_property
import threading
from datetime import datetime, timedelta, timezone
from typing import Optional

from google.oauth2 import service_account
from google.auth.credentials import Credentials
import google.auth.transport.requests

from pyfcm.errors import AuthenticationError, InvalidDataError


class TokenManager:
"""
Token management class extracted from BaseAPI.
Handles authentication credentials and access token lifecycle.
"""

def __init__(
self,
service_account_file: Optional[str] = None,
project_id: Optional[str] = None,
credentials: Optional[Credentials] = None,
):
"""
Initialize TokenManager

Args:
service_account_file (str): path to service account JSON file
project_id (str): project ID of Google account
credentials (Credentials): Google auth credentials instance
"""
if not (service_account_file or credentials):
raise AuthenticationError(
"Please provide a service account file path or credentials in the constructor"
)

self._service_account_file = service_account_file
self._project_id = project_id
self._provided_credentials = credentials

# Shared token management across threads
self._shared_token = None
self._token_lock = threading.RLock()

@cached_property
def _credentials(self) -> Credentials:
"""
Get authentication credentials

Returns:
Credentials: Google authentication credentials
"""
if self._provided_credentials is not None:
return self._provided_credentials

credentials = service_account.Credentials.from_service_account_file(
self._service_account_file,
scopes=["https://www.googleapis.com/auth/firebase.messaging"],
)
# Service account credentials has project_id (others are not)
self._project_id = credentials.project_id or self._project_id
self._service_account_file = None
return credentials

@cached_property
def project_id(self) -> str:
"""
Get project ID

Returns:
str: Project ID

Raises:
RuntimeError: If project_id is not configured
"""
# Read credentials to resolve project_id if needed
_ = self._credentials
if self._project_id is None:
raise RuntimeError(
"Please provide a project_id either explicitly or through Google credentials."
)
return self._project_id

def _is_token_valid(self) -> bool:
"""
Enhanced token validity check with fallback mechanisms.
Combines expired property check with time-based validation.

Returns:
bool: True if token is valid, False otherwise
"""
if not self._shared_token:
return False

if self._credentials.expired:
return False

# Fallback check: time-based validation with 5-minute buffer
# This accounts for the 4-minute early expiration issue
if (
hasattr(self._credentials, "expiry")
and self._credentials.expiry
and self._credentials.expiry
<= datetime.now(timezone.utc) + timedelta(minutes=5)
):
return False

return True

def get_access_token(self) -> str:
"""
Thread-safe access token management with shared token across threads.
Uses double-checked locking pattern for performance with enhanced validation.

Returns:
str: Access token

Raises:
InvalidDataError: If token acquisition fails
"""
# First check without lock (performance optimization)
if self._is_token_valid():
return self._shared_token

# Acquire lock and check again (double-checked locking)
with self._token_lock:
if self._is_token_valid():
return self._shared_token

try:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
self._shared_token = self._credentials.token
return self._shared_token
except Exception as e:
raise InvalidDataError(e)

def refresh_token_if_expired(self) -> None:
"""
Refresh token if needed
"""
with self._token_lock:
self._shared_token = None
if self._credentials:
try:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
except Exception:
# If refresh fails, let the next request handle it
pass
23 changes: 14 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@
from unittest.mock import AsyncMock

import pytest
from google.auth.credentials import Credentials

from pyfcm import FCMNotification
from pyfcm.baseapi import BaseAPI
from google.auth.credentials import Credentials


class DummyCredentials(Credentials):
def refresh():
pass
def __init__(self):
self.token = "dummy_token"
self._expired = True

def refresh(self, request):
self.token = "refreshed_dummy_token"
self._expired = False

@property
def project_id(self):
return "test"
def expired(self):
return self._expired


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def push_service():
return FCMNotification(credentials=DummyCredentials())
return FCMNotification(credentials=DummyCredentials(), project_id="test")


@pytest.fixture
Expand Down Expand Up @@ -48,6 +53,6 @@ def mock_aiohttp_session(mocker):
return mock_send


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def base_api():
return BaseAPI(credentials=DummyCredentials())
return BaseAPI(credentials=DummyCredentials(), project_id="test")
20 changes: 16 additions & 4 deletions tests/test_baseapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import json
import time

import pytest


def test_empty_project_id(base_api):
base_api.token_manager._project_id = None
with pytest.raises(RuntimeError) as e:
base_api.fcm_end_point
assert (
str(e.value)
== "Please provide a project_id either explicitly or through Google credentials."
)


def test_json_dumps(base_api):
Expand Down Expand Up @@ -46,7 +57,6 @@ def test_send_request_normal(base_api, mocker):

base_api.thread_local = mocker.Mock()
base_api.thread_local.requests_session = mock_session
base_api.thread_local.token_expiry = time.time() + 1000

# do
result = base_api.send_request(payload="test_payload", timeout=30)
Expand All @@ -73,7 +83,6 @@ def test_send_request_retry_after(base_api, mocker):

base_api.thread_local = mocker.Mock()
base_api.thread_local.requests_session = mock_session
base_api.thread_local.token_expiry = time.time() + 1000

# do
result = base_api.send_request(payload="test_payload", timeout=30)
Expand Down Expand Up @@ -118,11 +127,14 @@ def test_send_request_access_token_expired_retry(base_api, mocker):
type(base_api), "requests_session", new_callable=mocker.PropertyMock
)
mock_requests_session.return_value = mock_session
base_api.token_manager._shared_token = "dummy"
assert base_api.token_manager._shared_token is not None

# do
result = base_api.send_request(payload="test_payload", timeout=30)

# check
assert mock_session.post.call_count == 2
assert base_api.thread_local.token_expiry == 0
# token cleared, but not refreshed because request_session is mocked
assert base_api.token_manager._shared_token is None
assert result == success_response
Loading