-
Notifications
You must be signed in to change notification settings - Fork 10
Push notification token management #5178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
9ac98ba
d9471c5
7edd3ef
189df46
14990e6
8fecac0
5fb287e
6f91c6f
833ec10
31da761
bcae514
8997d37
24709f4
3a3033b
949f1ac
1c67610
b0222b3
2459551
086862e
669b1c9
bcae983
9508c8e
16d0f99
f484ac6
f668fe0
e51eb3e
de523dd
709be15
2664f47
2fc8548
aec4720
fcb8a26
9b30978
d9ff808
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| """Push notification tokens | ||
|
|
||
| Revision ID: 7d2194c5051e | ||
| Revises: 0b46effaf3a1 | ||
| Create Date: 2026-03-02 10:48:11.523814 | ||
|
|
||
| """ | ||
| from alembic import op | ||
| import sqlalchemy as sa | ||
| from sqlalchemy.dialects import postgresql | ||
|
|
||
| from wps_shared.db.models.common import TZTimeStamp | ||
|
|
||
| # revision identifiers, used by Alembic. | ||
| revision = '7d2194c5051e' | ||
| down_revision = '0b46effaf3a1' | ||
| branch_labels = None | ||
| depends_on = None | ||
|
|
||
|
|
||
| def upgrade(): | ||
| op.create_table('device_token', | ||
| sa.Column('id', sa.Integer(), nullable=False), | ||
| sa.Column('user_id', sa.String(), nullable=True), | ||
| sa.Column('platform', sa.String(), nullable=False), | ||
| sa.Column('token', sa.String(), nullable=False), | ||
| sa.Column('is_active', sa.Boolean(), nullable=False), | ||
| sa.Column('created_at', TZTimeStamp(), nullable=False), | ||
| sa.Column('updated_at', TZTimeStamp(), nullable=False), | ||
| sa.PrimaryKeyConstraint('id'), | ||
| comment='Device token management.' | ||
| ) | ||
| op.create_index(op.f('ix_device_token_id'), 'device_token', ['id'], unique=False) | ||
| op.create_index(op.f('ix_device_token_platform'), 'device_token', ['platform'], unique=False) | ||
| op.create_index(op.f('ix_device_token_token'), 'device_token', ['token'], unique=True) | ||
|
|
||
|
|
||
| def downgrade(): | ||
| op.drop_index(op.f('ix_device_token_token'), table_name='device_token') | ||
| op.drop_index(op.f('ix_device_token_platform'), table_name='device_token') | ||
| op.drop_index(op.f('ix_device_token_id'), table_name='device_token') | ||
| op.drop_table('device_token') | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
|
|
||
| import asyncio | ||
| from typing import List | ||
|
|
||
| from firebase_admin import messaging | ||
|
|
||
| from wps_shared.db.crud.fcm import deactivate_device_tokens | ||
| from wps_shared.db.database import get_async_write_session_scope | ||
| from wps_shared.db.models.fcm import DeviceToken | ||
| from wps_shared.utils.time import get_utc_now | ||
|
|
||
| # Simple exponential backoff with jitter for transient quota/server issues | ||
| async def _retry_send_multicast(multicast_msg: messaging.MulticastMessage, | ||
| max_retries: int = 5, | ||
| base_delay: float = 0.5): | ||
| attempt = 0 | ||
| while True: | ||
| try: | ||
| return messaging.send_multicast(multicast_msg, dry_run=False) | ||
| except Exception: | ||
| # Retry on probable transient conditions: quota (429), backend unavailable, etc. | ||
| # You can inspect e to match known transient cases in your logs. | ||
| attempt += 1 | ||
| if attempt > max_retries: | ||
| raise | ||
| # Exponential backoff with jitter | ||
| delay = (base_delay * (2 ** (attempt - 1))) + (0.1 * attempt) | ||
| await asyncio.sleep(delay) | ||
|
|
||
|
|
||
| async def deactivate_bad_tokens(db, tokens: List[str], responses): | ||
| """ | ||
| Deactivate tokens that failed with terminal errors like 'UNREGISTERED'. | ||
| """ | ||
| # For MulticastResponse: | ||
| # responses.responses[i].exception may contain details; many backends surface 'UNREGISTERED' | ||
| # when a token is invalid/stale. Remove/deactivate those. | ||
| stale_tokens: List[str] = [] | ||
| for idx, resp in enumerate(responses.responses): | ||
| if not resp.success: | ||
| exc = resp.exception | ||
| if exc and hasattr(exc, "code"): | ||
| code = getattr(exc, "code", None) | ||
| # TODO: Potentially expand this list based on observed error codes. | ||
| if str(code).upper() in {"UNREGISTERED"}: | ||
| stale_tokens.append(tokens[idx]) | ||
| token = tokens[idx] | ||
| db.query(DeviceToken).filter(DeviceToken.token == token).update( | ||
| {"is_active": False, "updated_at": get_utc_now()} | ||
| ) | ||
| async with get_async_write_session_scope() as session: | ||
| await deactivate_device_tokens(session, stale_tokens) | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from typing import Optional | ||
|
|
||
| from pydantic import BaseModel, Field | ||
|
|
||
|
|
||
| class RegisterDeviceRequest(BaseModel): | ||
| user_id: Optional[str] = None | ||
| token: str = Field(..., min_length=10) | ||
| platform: Optional[str] = Field(..., pattern="^(ios|android)?$") | ||
|
|
||
| class UnregisterDeviceRequest(BaseModel): | ||
| token: str | ||
conbrad marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| class DeviceRequestResponse(BaseModel): | ||
| success: bool | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| from fastapi import APIRouter | ||
| from wps_shared.db.crud.fcm import ( | ||
| get_device_by_token, | ||
| save_device_token, | ||
| update_device_token_is_active, | ||
| ) | ||
| from wps_shared.db.database import get_async_write_session_scope | ||
| from wps_shared.db.models.fcm import DeviceToken | ||
| from wps_shared.utils.time import get_utc_now | ||
|
|
||
| from app.fcm.schema import DeviceRequestResponse, RegisterDeviceRequest, UnregisterDeviceRequest | ||
|
|
||
| router = APIRouter(prefix="/device") | ||
|
|
||
|
|
||
| @router.post("/register") | ||
conbrad marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| async def register_device(request: RegisterDeviceRequest): | ||
conbrad marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Upsert a device token for a user. Called this at app start and whenever FCM token refreshes. | ||
| """ | ||
| async with get_async_write_session_scope() as session: | ||
| existing = await get_device_by_token(session, request.token) | ||
| if existing: | ||
| existing.is_active = True | ||
| existing.token = request.token | ||
| existing.updated_at = get_utc_now() | ||
conbrad marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+32
to
+35
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just thinking about work phones, and on the assumption that the
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After doing a bit of reading I don't think we need to do this. If I understand Apple's page properly there would be a new device id if the phone was wiped and given to someone else |
||
| else: | ||
| device_token = DeviceToken( | ||
| user_id=request.user_id, | ||
| token=request.token, | ||
| platform=request.platform, | ||
| is_active=True, | ||
| ) | ||
| save_device_token(session, device_token) | ||
| return DeviceRequestResponse(success=True) | ||
|
|
||
|
|
||
| @router.delete("/unregister") | ||
conbrad marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| async def unregister_device(request: UnregisterDeviceRequest): | ||
| """ | ||
| Mark a token inactive (e.g., user logged out or uninstalled). | ||
| """ | ||
| async with get_async_write_session_scope() as session: | ||
| await update_device_token_is_active(session, request.token) | ||
conbrad marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return DeviceRequestResponse(success=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| """ Unit tests for FCM endpoints. | ||
| """ | ||
| from starlette.testclient import TestClient | ||
| import app.main | ||
| from unittest.mock import patch | ||
| from datetime import datetime | ||
|
|
||
|
|
||
| def test_register_device_success(): | ||
| """Test that device registration returns 200/OK.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| # Test data | ||
| request_data = { | ||
| "user_id": "test-user-123", | ||
| "token": "test-fcm-token-456", | ||
| "platform": "android" | ||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
| with patch('app.routers.fcm.get_device_by_token', return_value=None), \ | ||
| patch('app.routers.fcm.save_device_token'): | ||
|
|
||
| response = client.post("/api/device/register", json=request_data) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.json()["success"] == True | ||
| assert response.headers["content-type"] == "application/json" | ||
|
|
||
|
|
||
| def test_register_device_already_exists(): | ||
| """Test that existing device registration updates successfully.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| request_data = { | ||
| "user_id": "test-user-123", | ||
| "token": "existing-fcm-token", | ||
|
||
| "platform": "ios" | ||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
|
|
||
| existing_device = type('', (object,), { | ||
| 'is_active': False, | ||
| 'token': 'existing-fcm-token', | ||
|
||
| 'updated_at': datetime(2026, 1, 1) | ||
| })() | ||
|
|
||
| with patch('app.routers.fcm.get_device_by_token', return_value=existing_device), \ | ||
| patch('app.routers.fcm.save_device_token'): | ||
|
|
||
| response = client.post("/api/device/register", json=request_data) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.json()["success"] == True | ||
| assert existing_device.is_active == True # Should be updated | ||
|
|
||
|
|
||
| def test_register_device_missing_fields(): | ||
| """Test that missing fields in registration request returns 422.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| # Missing 'token' field which is required | ||
| request_data = { | ||
| "user_id": "test-user-123", | ||
| "platform": "android" | ||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
|
|
||
| response = client.post("/api/device/register", json=request_data) | ||
|
|
||
| assert response.status_code == 422 | ||
|
|
||
|
|
||
| def test_register_device_invalid_platform(): | ||
| """Test that invalid platform returns 422.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| request_data = { | ||
| "user_id": "test-user-123", | ||
| "token": "test-fcm-token", | ||
| "platform": "invalid-platform", | ||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
|
|
||
| response = client.post("/api/device/register", json=request_data) | ||
|
|
||
| assert response.status_code == 422 | ||
|
|
||
|
|
||
| def test_register_device_short_token(): | ||
| """Test that short token returns 422.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| request_data = { | ||
| "user_id": "test-user-123", | ||
| "token": "short", # Less than 10 characters | ||
| "platform": "android", | ||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
|
|
||
| response = client.post("/api/device/register", json=request_data) | ||
|
|
||
| assert response.status_code == 422 | ||
|
|
||
|
|
||
| def test_unregister_device_success(): | ||
| """Test that device unregistration returns 200/OK.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| request_data = { | ||
| "token": "test-fcm-token-456" | ||
|
||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
| with patch('app.routers.fcm.update_device_token_is_active'): | ||
|
|
||
| response = client.request("DELETE", "/api/device/unregister", json=request_data) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.json()["success"] == True | ||
|
|
||
|
|
||
| def test_unregister_device_missing_token(): | ||
| """Test that missing token field returns 422.""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| request_data = {} | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session_scope.return_value.__aenter__.return_value | ||
|
|
||
| response = client.request("DELETE", "/api/device/unregister", json=request_data) | ||
|
|
||
| assert response.status_code == 422 | ||
|
|
||
|
|
||
| def test_register_device_without_user_id(): | ||
| """Test that device registration without user_id is allowed (null user).""" | ||
| client = TestClient(app.main.app) | ||
|
|
||
| request_data = { | ||
| "token": "test-fcm-token-789", | ||
| "platform": "android", | ||
| } | ||
|
|
||
| with patch('app.routers.fcm.get_async_write_session_scope') as mock_session_scope: | ||
| mock_session = mock_session_scope.return_value.__aenter__.return_value | ||
| with patch('app.routers.fcm.get_device_by_token', return_value=None), \ | ||
| patch('app.routers.fcm.save_device_token'): | ||
|
|
||
| response = client.post("/api/device/register", json=request_data) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.json()["success"] == True | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from sqlalchemy import select, update | ||
| from sqlalchemy.ext.asyncio import AsyncSession | ||
|
|
||
| from wps_shared.db.models.fcm import DeviceToken | ||
| from wps_shared.utils.time import get_utc_now | ||
|
|
||
|
|
||
| def save_device_token(session: AsyncSession, device_token: DeviceToken): | ||
conbrad marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Add a new DeviceToken for tracking devices registered for push notifications. | ||
| :param session: An async database session. | ||
| :param device_token: The record to be saved. | ||
| :type device_token: DeviceToken | ||
| """ | ||
| session.add(device_token) | ||
brettedw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| async def get_device_by_token(session: AsyncSession, token: str): | ||
| """ | ||
| Lookup a DeviceToken by token value. | ||
|
|
||
| :param session: An async database session | ||
| :param token: A token for a registered device. | ||
| :return: A DeviceToken object or None. | ||
| """ | ||
| return await session.scalar(select(DeviceToken).where(DeviceToken.token == token)) | ||
|
|
||
|
|
||
| async def update_device_token_is_active(session: AsyncSession, token: str): | ||
| device_token = await session.scalar(select(DeviceToken).where(DeviceToken.token == token)) | ||
| if not device_token: | ||
| raise ValueError(f"DeviceToken with token {token} does not exist.") | ||
| device_token.is_active = False | ||
| device_token.updated_at = get_utc_now() | ||
|
|
||
|
|
||
| async def deactivate_device_tokens(session: AsyncSession, tokens: list[str]) -> int: | ||
conbrad marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if not tokens: | ||
| return 0 | ||
|
|
||
| stmt = ( | ||
| update(DeviceToken) | ||
| .where(DeviceToken.token.in_(tokens)) | ||
| .values( | ||
| is_active=False, | ||
| updated_at=get_utc_now(), | ||
| ) | ||
| # No need to synchronize the session: set-based UPDATE + no ORM objects loaded. | ||
| .execution_options(synchronize_session=False) | ||
| ) | ||
| result = await session.execute(stmt) | ||
|
|
||
| return result.rowcount or 0 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: downgrade also should drop the
platformenumI think