diff --git a/genesis_notification/common/constants.py b/genesis_notification/common/constants.py index c32d3f0..072cc26 100644 --- a/genesis_notification/common/constants.py +++ b/genesis_notification/common/constants.py @@ -29,3 +29,9 @@ class EventStatus(str, enum.Enum): IN_PROGRESS = "IN_PROGRESS" ACTIVE = "ACTIVE" ERROR = "ERROR" + + +class PushDeliveryStatus(str, enum.Enum): + SUCCESS = "SUCCESS" + PERMANENT_FAILURE = "PERMANENT_FAILURE" + RETRYABLE_FAILURE = "RETRYABLE_FAILURE" diff --git a/genesis_notification/dm/models.py b/genesis_notification/dm/models.py index c8fa255..99ad133 100644 --- a/genesis_notification/dm/models.py +++ b/genesis_notification/dm/models.py @@ -15,7 +15,9 @@ # under the License. import datetime +import json import logging +from dataclasses import dataclass from email.mime import text from email.mime import multipart import smtplib @@ -30,8 +32,16 @@ from restalchemy.storage.sql import orm import zulip -from genesis_notification.common import constants as c +import firebase_admin +from firebase_admin import credentials, messaging +from firebase_admin.messaging import ( + UnregisteredError, + SenderIdMismatchError, + ThirdPartyAuthError, +) +from genesis_notification.common import constants as c +from genesis_notification.common.constants import PushDeliveryStatus LOG = logging.getLogger(__name__) @@ -56,6 +66,253 @@ class ModelWithAlwaysActiveStatus(models.Model): ) +class Installation( + models.ModelWithUUID, + ModelWithAlwaysActiveStatus, + models.ModelWithTimestamp, + orm.SQLStorableMixin, +): + __tablename__ = "installations" + + installation_id = properties.property( + types.String(min_length=8, max_length=128), + required=True, + ) + + user_id = properties.property( + types.UUID(), + required=True, + ) + + push_token = properties.property( + types.String(min_length=16, max_length=512), + required=True, + ) + + platform = properties.property( + types.Enum(["ios", "android", "web"]), + required=True, + ) + + app_version = properties.property( + types.String(max_length=16), + required=True, + ) + + os_version = properties.property( + types.String(max_length=16), + required=True, + ) + + device_model = properties.property( + types.String(max_length=16), + required=True, + ) + + +FCM_PERMANENT_ERRORS = { + "UNREGISTERED", + "INVALID_ARGUMENT", + "NOT_FOUND", +} + +FCM_RETRYABLE_ERRORS = { + "UNAVAILABLE", + "INTERNAL", + "QUOTA_EXCEEDED", +} + + +@dataclass +class PushDeliveryResult: + + installation_id: str + token: str + + status: PushDeliveryStatus + + error_code: str | None = None + error_message: str | None = None + + provider_response: dict | None = None + + +@dataclass +class PushBatchResult: + + results: list[PushDeliveryResult] + + def success_count(self): + return sum(1 for r in self.results if r.status == PushDeliveryStatus.SUCCESS) + + def permanent_failures(self): + return [ + r for r in self.results + if r.status == PushDeliveryStatus.PERMANENT_FAILURE + ] + + def retryable_failures(self): + return [ + r for r in self.results + if r.status == PushDeliveryStatus.RETRYABLE_FAILURE + ] + + def total_failure(self): + return self.success_count() == 0 + + +class FCMProtocol(types_dynamic.AbstractKindModel): + KIND = "fcm" + + project_id = properties.property( + types.String(), + required=True, + ) + + service_account_json = properties.property( + types.String(), + required=True, + ) + + def _get_firebase_app(self): + + service_account_info = json.loads(self.service_account_json) + + cred = credentials.Certificate(service_account_info) + + app_name = f"fcm-{self.project_id}" + + try: + app = firebase_admin.get_app(app_name) + except ValueError: + app = firebase_admin.initialize_app( + cred, + name=app_name, + ) + + return app + + def _map_exception(self, exc): + + if isinstance(exc, ( + UnregisteredError, + SenderIdMismatchError, + )): + return PushDeliveryStatus.PERMANENT_FAILURE + + if isinstance(exc, ( + ThirdPartyAuthError, + )): + return PushDeliveryStatus.RETRYABLE_FAILURE + + return PushDeliveryStatus.RETRYABLE_FAILURE + + def _send_batch(self, installations, content): + + app = self._get_firebase_app() + + tokens = [i.push_token for i in installations] + installation_map = { + i.push_token: i.installation_id + for i in installations + } + + results = [] + + for i in range(0, len(tokens), 500): + + chunk = tokens[i:i + 500] + + message = messaging.MulticastMessage( + tokens=chunk, + notification=messaging.Notification( + title=content.title, + body=content.body, + ), + data=content.data or {}, + ) + + # In firebase-admin 6.x+, send_multicast was replaced with send_each_for_multicast + response = messaging.send_each_for_multicast(message, app=app) + + for idx, resp in enumerate(response.responses): + + token = chunk[idx] + installation_id = installation_map[token] + + if resp.success: + + results.append( + PushDeliveryResult( + installation_id=installation_id, + token=token, + status=PushDeliveryStatus.SUCCESS, + provider_response={"message_id": resp.message_id}, + ) + ) + + else: + status = self._map_exception(resp.exception) + + results.append( + PushDeliveryResult( + installation_id=installation_id, + token=token, + status=status, + error_code=type(resp.exception).__name__, + error_message=str(resp.exception), + ) + ) + + return PushBatchResult(results) + + def _build_payload(self, token, content): + + return { + "message": { + "token": token, + "notification": { + "title": content.title, + "body": content.body, + }, + "data": content.data or {}, + } + } + + def _process_batch_result(self, batch_result): + + for r in batch_result.permanent_failures(): + + inst = Installation.objects.get_one( + filters={"installation_id": r.installation_id} + ) + + if inst: + inst.status = c.AlwaysActiveStatus.INACTIVE.value + inst.save() + + def send(self, content, user_context): + + user_id = user_context["user"]["uuid"] + + installations = Installation.objects.get_all( + filters={ + "user_id": filters.EQ(user_id), + "status": filters.EQ(c.AlwaysActiveStatus.ACTIVE.value), + } + ) + + if not installations: + return + + batch_result = self._send_batch(installations, content) + + self._process_batch_result(batch_result) + + if batch_result.total_failure(): + raise RuntimeError("Push delivery totally failed") + + class SimpleSmtpProtocol(types_dynamic.AbstractKindModel): KIND = "SimpleSMTP" @@ -172,6 +429,7 @@ class Provider( types_dynamic.KindModelType(SimpleSmtpProtocol), types_dynamic.KindModelType(StartTlsSmtpProtocol), types_dynamic.KindModelType(ZulipProtocol), + types_dynamic.KindModelType(FCMProtocol), ), required=True, ) @@ -311,6 +569,28 @@ def render(self, params): } +class RenderedPushContent(AbstractContent): + KIND = "rendered_push" + + title = properties.property(types.String(), default="{{ title }}") + body = properties.property(types.String(), default="{{ body }}") + data = properties.property(types.Dict(), default=dict) + + +class PushContent(RenderedPushContent): + KIND = "push" + + def render(self, params): + return RenderedPushContent( + title=jinja2.Template(self.title).render(**params), + body=jinja2.Template(self.body).render(**params), + data={ + k: jinja2.Template(v).render(**params) if isinstance(v, str) else v + for k, v in self.data.items() + }, + ) + + class Template( models.ModelWithUUID, models.ModelWithRequiredNameDesc, @@ -326,6 +606,7 @@ class Template( types_dynamic.KindModelType(EmailContent), types_dynamic.KindModelType(ZulipStreamMessageContent), types_dynamic.KindModelType(ZulipDirectMessageContent), + types_dynamic.KindModelType(PushContent), ), required=True, ) @@ -520,6 +801,7 @@ class RenderedEvent( types_dynamic.KindModelType(RenderedEmailContent), types_dynamic.KindModelType(RenderedStreamMessageContent), types_dynamic.KindModelType(RenderedDirectMessageContent), + types_dynamic.KindModelType(RenderedPushContent), ), required=True, ) diff --git a/genesis_notification/user_api/api/controllers.py b/genesis_notification/user_api/api/controllers.py index b87eb8e..e95f598 100644 --- a/genesis_notification/user_api/api/controllers.py +++ b/genesis_notification/user_api/api/controllers.py @@ -13,10 +13,12 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import datetime from restalchemy.api import controllers as ra_controllers from restalchemy.api import resources +from genesis_notification.common import constants as c from genesis_notification.dm import models from genesis_notification.user_api.api import versions @@ -56,3 +58,39 @@ class EventController(ra_controllers.BaseResourceController): __resource__ = resources.ResourceByRAModel( models.Event, convert_underscore=False ) + + +class InstallationController(ra_controllers.BaseResourceController): + __resource__ = resources.ResourceByRAModel( + models.Installation, + convert_underscore=False, + ) + + def _update_existing(self, existing, resource): + existing.push_token = resource["push_token"] + existing.platform = resource["platform"] + + existing.app_version = resource.get("app_version", "") + existing.os_version = resource.get("os_version", "") + existing.device_model = resource.get("device_model", "") + + existing.status = c.AlwaysActiveStatus.ACTIVE.value + existing.user_id = resource.get("user_id", "") + + existing.save() + + return existing + + def create(self, **kwargs): + installation_id = kwargs.get("installation_id") + + existing = models.Installation.objects.get_one( + filters={ + "installation_id": installation_id, + } + ) + + if existing: + return self._update_existing(existing, kwargs) + + return super().create(**kwargs) diff --git a/genesis_notification/user_api/api/routes.py b/genesis_notification/user_api/api/routes.py index f2e0370..3301368 100644 --- a/genesis_notification/user_api/api/routes.py +++ b/genesis_notification/user_api/api/routes.py @@ -35,6 +35,10 @@ class EventRoute(routes.Route): __controller__ = controllers.EventController +class InstallationRoute(routes.Route): + __controller__ = controllers.InstallationController + + class ApiEndpointRoute(routes.Route): """Handler for /v1.0/ endpoint""" @@ -45,3 +49,4 @@ class ApiEndpointRoute(routes.Route): templates = routes.route(TemplateRoute) event_types = routes.route(EventTypeRoute) events = routes.route(EventRoute) + installations = routes.route(InstallationRoute) diff --git a/migrations/0001-fcm-installation-model-232cd7.py b/migrations/0001-fcm-installation-model-232cd7.py new file mode 100644 index 0000000..738cdad --- /dev/null +++ b/migrations/0001-fcm-installation-model-232cd7.py @@ -0,0 +1,72 @@ +# Copyright 2016 Eugene Frolov +# +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from restalchemy.storage.sql import migrations + + +class MigrationStep(migrations.AbstractMigrationStep): + + def __init__(self): + self._depends = ["0000-init-tables-40a307.py"] + + @property + def migration_id(self): + return "232cd714-c3c7-448e-a852-c0787b20b5ef" + + @property + def is_manual(self): + return False + + def upgrade(self, session): + expressions = [ + """ + CREATE TABLE "installations" ( + "uuid" CHAR(36) PRIMARY KEY, + "status" enum_status_active NOT NULL DEFAULT 'ACTIVE', + "created_at" TIMESTAMP(6) NOT NULL DEFAULT NOW(), + "updated_at" TIMESTAMP(6) NOT NULL DEFAULT NOW(), + + "installation_id" VARCHAR(128) NOT NULL, + "user_id" CHAR(36) NOT NULL, + "platform" VARCHAR(32) NOT NULL, + "push_token" TEXT NOT NULL, + "app_version" CHAR(16) NOT NULL, + "os_version" CHAR(16) NOT NULL, + "device_model" CHAR(16) NOT NULL + ); + """, + """ + CREATE INDEX "installations_user_idx" + ON "installations" ("user_id"); + """, + """ + CREATE INDEX "installations_token_idx" + ON "installations" ("push_token"); + """, + """ + CREATE INDEX "installations_installation_id_idx" + ON "installations" ("installation_id"); + """, + ] + + for expression in expressions: + session.execute(expression) + + def downgrade(self, session): + self._delete_table_if_exists(session, "installations") + + +migration_step = MigrationStep() diff --git a/requirements.txt b/requirements.txt index 9904182..17dfc41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ pbr>=1.10.0,<=5.8.1 # Apache-2.0 oslo.config>=3.22.2,<10.0.0 # Apache-2.0 -restalchemy>=14.1.0,<15.0.0 # Apache-2.0 +restalchemy>=15.0.0,<16.0.0 # Apache-2.0 gcl_iam>=0.8.0,<1.0.0 # Apache-2.0 gcl_looper>=0.1.0,<=1.0.0 # Apache-2.0 bjoern>=3.2.2 # BSD License (BSD-3-Clause) Jinja2>=3.1.5,<4.0.0 # BSD License (BSD-3-Clause) bazooka>=1.3.0,<2.0.0 # Apache-2.0 zulip>=0.9.0,<1.0.0 # Apache-2.0 +firebase_admin >= 7.2.0