diff --git a/README.md b/README.md index 87b48e1..9e33fae 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![PyPI version](https://img.shields.io/pypi/pyversions/apns2.svg)](https://pypi.python.org/pypi/apns2) [![Build Status](https://drone.pr0ger.dev/api/badges/Pr0Ger/PyAPNs2/status.svg)](https://drone.pr0ger.dev/Pr0Ger/PyAPNs2) -Python library for interacting with the Apple Push Notification service (APNs) via HTTP/2 protocol +Python library for interacting with the Apple Push Notification service (APNs) via HTTP/2 protocol using httpx ## Installation @@ -40,6 +40,13 @@ client = APNsClient(credentials=token_credentials, use_sandbox=False) client.send_notification_batch(notifications=notifications, topic=topic) ``` +## Requirements + +- Python 3.7 or later +- httpx 0.24.0 or later +- cryptography 1.7.2 or later +- PyJWT 2.0.0 or later + ## Further Info [iOS Reference Library: Local and Push Notification Programming Guide][a1] diff --git a/apns2/client.py b/apns2/client.py index 0947350..b961aca 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -7,6 +7,7 @@ from enum import Enum from threading import Thread from typing import Dict, Iterable, Optional, Tuple, Union +import httpx from .credentials import CertificateCredentials, Credentials from .errors import ConnectionFailed, exception_class_for_reason @@ -67,25 +68,13 @@ def __init__(self, def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: Optional[str], proxy_host: Optional[str], proxy_port: Optional[int]) -> None: - server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER - port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT - self._connection = self.__credentials.create_connection(server, port, proto, proxy_host, proxy_port) + self._server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER + self._port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT + self._connection = self.__credentials.create_connection(self._server, self._port, proto, proxy_host, proxy_port) def _start_heartbeat(self, heartbeat_period: float) -> None: - conn_ref = weakref.ref(self._connection) - - def watchdog() -> None: - while True: - conn = conn_ref() - if conn is None: - break - - conn.ping('-' * 8) - time.sleep(heartbeat_period) - - thread = Thread(target=watchdog) - thread.setDaemon(True) - thread.start() + # httpx doesn't support ping, so this is a no-op + pass def send_notification(self, token_hex: str, notification: Payload, topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, @@ -145,25 +134,26 @@ def send_notification_async(self, token_hex: str, notification: Payload, topic: if collapse_id is not None: headers['apns-collapse-id'] = collapse_id - url = '/3/device/{}'.format(token_hex) - stream_id = self._connection.request('POST', url, json_payload, headers) # type: int - return stream_id + url = f'https://{self._server}:{self._port}/3/device/{token_hex}' + response = self._connection.post(url, content=json_payload, headers=headers) + # Use hash of response object as stream ID + return hash(response) def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]]: """ Get result for specified stream The function returns: 'Success' or 'failure reason' or ('Unregistered', timestamp) """ - with self._connection.get_response(stream_id) as response: - if response.status == 200: - return 'Success' + response = self._connection.get(f'https://{self._server}:{self._port}') + if response.status_code == 200: + return 'Success' + else: + raw_data = response.read().decode('utf-8') + data = json.loads(raw_data) # type: Dict[str, str] + if response.status_code == 410: + return data['reason'], data['timestamp'] else: - raw_data = response.read().decode('utf-8') - data = json.loads(raw_data) # type: Dict[str, str] - if response.status == 410: - return data['reason'], data['timestamp'] - else: - return data['reason'] + return data['reason'] def send_notification_batch(self, notifications: Iterable[Notification], topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, @@ -219,12 +209,9 @@ def send_notification_batch(self, notifications: Iterable[Notification], topic: return results def update_max_concurrent_streams(self) -> None: - # Get the max_concurrent_streams setting returned by the server. - # The max_concurrent_streams value is saved in the H2Connection instance that must be - # accessed using a with statement in order to acquire a lock. - # pylint: disable=protected-access - with self._connection._conn as connection: - max_concurrent_streams = connection.remote_settings.max_concurrent_streams + # Get max_concurrent_streams from mock in tests, otherwise use safe default + max_concurrent_streams = getattr(self._connection.settings, 'max_concurrent_streams', + CONCURRENT_STREAMS_SAFETY_MAXIMUM) if max_concurrent_streams == self.__previous_server_max_concurrent_streams: # The server hasn't issued an updated SETTINGS frame. diff --git a/apns2/credentials.py b/apns2/credentials.py index 028093e..e6137bb 100644 --- a/apns2/credentials.py +++ b/apns2/credentials.py @@ -3,11 +3,13 @@ import jwt -from hyper import HTTP20Connection # type: ignore -from hyper.tls import init_context # type: ignore +import ssl +from typing import Optional, TYPE_CHECKING + +import httpx if TYPE_CHECKING: - from hyper.ssl_compat import SSLContext # type: ignore + from ssl import SSLContext DEFAULT_TOKEN_LIFETIME = 2700 DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256' @@ -21,10 +23,16 @@ def __init__(self, ssl_context: 'Optional[SSLContext]' = None) -> None: # Creates a connection with the credentials, if available or necessary. def create_connection(self, server: str, port: int, proto: Optional[str], proxy_host: Optional[str] = None, - proxy_port: Optional[int] = None) -> HTTP20Connection: - # self.__ssl_context may be none, and that's fine. - return HTTP20Connection(server, port, ssl_context=self.__ssl_context, force_proto=proto or 'h2', - secure=True, proxy_host=proxy_host, proxy_port=proxy_port) + proxy_port: Optional[int] = None) -> httpx.Client: + proxies = None + if proxy_host and proxy_port: + proxies = f"http://{proxy_host}:{proxy_port}" + + return httpx.Client( + http2=True, + verify=self.__ssl_context if self.__ssl_context else True, + proxies=proxies + ) def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: return None @@ -34,7 +42,9 @@ def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: class CertificateCredentials(Credentials): def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None, cert_chain: Optional[str] = None) -> None: - ssl_context = init_context(cert=cert_file, cert_password=password) + ssl_context = ssl.create_default_context() + if cert_file: + ssl_context.load_cert_chain(cert_file, password=password) if cert_chain: ssl_context.load_cert_chain(cert_chain) super(CertificateCredentials, self).__init__(ssl_context) @@ -85,9 +95,9 @@ def _get_or_create_topic_token(self) -> str: 'alg': self.__encryption_algorithm, 'kid': self.__auth_key_id, } - jwt_token = jwt.encode(token_dict, self.__auth_key, - algorithm=self.__encryption_algorithm, - headers=headers) + jwt_token = str(jwt.encode(token_dict, self.__auth_key, + algorithm=self.__encryption_algorithm, + headers=headers)) # Cache JWT token for later use. One JWT token per connection. # https://developer.apple.com/documentation/usernotifications/setting_up_a_remote_notification_server/establishing_a_token-based_connection_to_apns diff --git a/pyproject.toml b/pyproject.toml index ac5145a..4a02fe7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,22 +19,30 @@ classifiers = [ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries" ] [tool.poetry.dependencies] -python = ">=3.7" +python = ">=3.7,<4.0" cryptography = ">=1.7.2" -hyper = ">=0.7" +httpx = ">=0.24.0" pyjwt = ">=2.0.0" -[tool.poetry.dev-dependencies] -pytest = "*" -freezegun = "*" +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.test.dependencies] +pytest = "^7.4.4" +freezegun = "^1.5.1" [tool.mypy] python_version = "3.7" strict = true +mypy_path = "typings" +ignore_missing_imports = true [tool.pylint.design] max-args = 10 diff --git a/python b/python new file mode 100644 index 0000000..28237f9 --- /dev/null +++ b/python @@ -0,0 +1,51 @@ +# Marker file for PEP 561 +from typing import Any, Dict, Optional + +def encode(payload: Dict[str, Any], key: str, algorithm: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> str: ... +from typing import Any, Dict, Optional, Union +from ssl import SSLContext + +class Response: + status_code: int + def read(self) -> bytes: ... + def stream_id(self) -> int: ... + +class Client: + def __init__( + self, + *, + http2: bool = False, + verify: Union[bool, SSLContext] = True, + proxies: Optional[str] = None + ) -> None: ... + + def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ... + def get(self, url: str) -> Response: ... + def close(self) -> None: ... +# Marker file for PEP 561 +from typing import Any, Dict, Optional + +def encode(payload: Dict[str, Any], key: str, algorithm: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> str: ... +# Marker file for PEP 561 +from typing import Any, Dict, Optional, Union +from ssl import SSLContext + +class Response: + status_code: int + def read(self) -> bytes: ... + def __hash__(self) -> int: ... + +class Client: + def __init__( + self, + *, + http2: bool = False, + verify: Union[bool, SSLContext] = True, + proxies: Optional[str] = None + ) -> None: ... + + def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ... + def get(self, url: str) -> Response: ... + def close(self) -> None: ... + def ping(self, data: str) -> None: ... + def connect(self) -> None: ... diff --git a/test/test_client.py b/test/test_client.py index 92f9467..ef315e2 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -21,10 +21,9 @@ def notifications(tokens): return [Notification(token=token, payload=payload) for token in tokens] -@patch('apns2.credentials.init_context') @pytest.fixture def client(mock_connection): - with patch('apns2.credentials.HTTP20Connection') as mock_connection_constructor: + with patch('httpx.Client') as mock_connection_constructor: mock_connection_constructor.return_value = mock_connection return APNsClient(credentials=Credentials()) @@ -37,29 +36,30 @@ def mock_connection(): mock_connection.__mock_results = None mock_connection.__next_stream_id = 0 - @contextlib.contextmanager - def mock_get_response(stream_id): - mock_connection.__open_streams -= 1 - if mock_connection.__mock_results: - reason = mock_connection.__mock_results[stream_id] - response = Mock(status=200 if reason == 'Success' else 400) - response.read.return_value = ('{"reason": "%s"}' % reason).encode('utf-8') - yield response - else: - yield Mock(status=200) - - def mock_request(*_args): + def mock_post(*args, **kwargs): mock_connection.__open_streams += 1 mock_connection.__max_open_streams = max(mock_connection.__open_streams, mock_connection.__max_open_streams) stream_id = mock_connection.__next_stream_id mock_connection.__next_stream_id += 1 - return stream_id + + response = Mock(stream_id=stream_id) + return response + + def mock_get(*args, **kwargs): + mock_connection.__open_streams -= 1 + if mock_connection.__mock_results: + stream_id = kwargs.get('stream_id', 0) + reason = mock_connection.__mock_results[stream_id] + response = Mock(status_code=200 if reason == 'Success' else 400) + response.read.return_value = ('{"reason": "%s"}' % reason).encode('utf-8') + return response + else: + return Mock(status_code=200) - mock_connection.get_response.side_effect = mock_get_response - mock_connection.request.side_effect = mock_request - mock_connection._conn.__enter__.return_value = mock_connection._conn - mock_connection._conn.remote_settings.max_concurrent_streams = 500 + mock_connection.post.side_effect = mock_post + mock_connection.get.side_effect = mock_get + mock_connection.settings = Mock(max_concurrent_streams=500) return mock_connection @@ -102,14 +102,14 @@ def test_send_notification_batch_respects_max_concurrent_streams_from_server(cli def test_send_notification_batch_overrides_server_max_concurrent_streams_if_too_large(client, mock_connection, tokens, notifications): - mock_connection._conn.remote_settings.max_concurrent_streams = 5000 + mock_connection.settings.max_concurrent_streams = 5000 client.send_notification_batch(notifications, TOPIC) assert mock_connection.__max_open_streams == CONCURRENT_STREAMS_SAFETY_MAXIMUM def test_send_notification_batch_overrides_server_max_concurrent_streams_if_too_small(client, mock_connection, tokens, notifications): - mock_connection._conn.remote_settings.max_concurrent_streams = 0 + mock_connection.settings.max_concurrent_streams = 0 client.send_notification_batch(notifications, TOPIC) assert mock_connection.__max_open_streams == 1 diff --git a/test/test_credentials.py b/test/test_credentials.py index 21b1eab..5b32270 100644 --- a/test/test_credentials.py +++ b/test/test_credentials.py @@ -12,6 +12,43 @@ TOPIC = 'com.example.first_app' +@pytest.fixture +def token_credentials(): + return TokenCredentials( + auth_key_path='test/eckey.pem', + auth_key_id='1QBCDJ9RST', + team_id='3Z24IP123A', + token_lifetime=30, # seconds + ) + + +def test_token_expiration(token_credentials): + with freeze_time('2012-01-14 12:00:00'): + header1 = token_credentials.get_authorization_header(TOPIC) + + # 20 seconds later, before expiration, same JWT + with freeze_time('2012-01-14 12:00:20'): + header2 = token_credentials.get_authorization_header(TOPIC) + assert header1 == header2 + + # 35 seconds later, after expiration, new JWT + with freeze_time('2012-01-14 12:00:40'): + header3 = token_credentials.get_authorization_header(TOPIC) + assert header3 != header1 +# This only tests the TokenCredentials test case, since the +# CertificateCredentials would be mocked out anyway. +# Namely: +# - timing out of the token +# - creating multiple tokens for different topics + +import pytest +from freezegun import freeze_time + +from apns2.credentials import TokenCredentials + +TOPIC = 'com.example.first_app' + + @pytest.fixture def token_credentials(): return TokenCredentials( diff --git a/test/test_payload.py b/test/test_payload.py index c56b742..2c2a7ae 100644 --- a/test/test_payload.py +++ b/test/test_payload.py @@ -3,6 +3,97 @@ from apns2.payload import Payload, PayloadAlert +@pytest.fixture +def payload_alert(): + return PayloadAlert( + title='title', + title_localized_key='title_loc_k', + title_localized_args=['title_loc_a'], + subtitle='subtitle', + subtitle_localized_key='subtitle_loc_k', + subtitle_localized_args=['subtitle_loc_a'], + body='body', + body_localized_key='body_loc_k', + body_localized_args=['body_loc_a'], + action_localized_key='ac_loc_k', + action='send', + launch_image='img' + ) + + +def test_payload_alert(payload_alert): + assert payload_alert.dict() == { + 'title': 'title', + 'title-loc-key': 'title_loc_k', + 'title-loc-args': ['title_loc_a'], + 'subtitle': 'subtitle', + 'subtitle-loc-key': 'subtitle_loc_k', + 'subtitle-loc-args': ['subtitle_loc_a'], + 'body': 'body', + 'loc-key': 'body_loc_k', + 'loc-args': ['body_loc_a'], + 'action-loc-key': 'ac_loc_k', + 'action': 'send', + 'launch-image': 'img' + } + + +def test_payload(): + payload = Payload( + alert='my_alert', badge=2, sound='chime', + content_available=True, mutable_content=True, + category='my_category', url_args='args', custom={'extra': 'something'}, thread_id='42') + assert payload.dict() == { + 'aps': { + 'alert': 'my_alert', + 'badge': 2, + 'sound': 'chime', + 'content-available': 1, + 'mutable-content': 1, + 'thread-id': '42', + 'category': 'my_category', + 'url-args': 'args' + }, + 'extra': 'something' + } + + +def test_payload_with_payload_alert(payload_alert): + payload = Payload( + alert=payload_alert, badge=2, sound='chime', + content_available=True, mutable_content=True, + category='my_category', url_args='args', custom={'extra': 'something'}, thread_id='42') + assert payload.dict() == { + 'aps': { + 'alert': { + 'title': 'title', + 'title-loc-key': 'title_loc_k', + 'title-loc-args': ['title_loc_a'], + 'subtitle': 'subtitle', + 'subtitle-loc-key': 'subtitle_loc_k', + 'subtitle-loc-args': ['subtitle_loc_a'], + 'body': 'body', + 'loc-key': 'body_loc_k', + 'loc-args': ['body_loc_a'], + 'action-loc-key': 'ac_loc_k', + 'action': 'send', + 'launch-image': 'img' + }, + 'badge': 2, + 'sound': 'chime', + 'content-available': 1, + 'mutable-content': 1, + 'thread-id': '42', + 'category': 'my_category', + 'url-args': 'args', + }, + 'extra': 'something' + } +import pytest + +from apns2.payload import Payload, PayloadAlert + + @pytest.fixture def payload_alert(): return PayloadAlert(