diff --git a/apns2/payload.py b/apns2/payload.py index 52cdd52..54edcea 100644 --- a/apns2/payload.py +++ b/apns2/payload.py @@ -1,7 +1,9 @@ -from typing import Any, Dict, List, Optional, Union, Iterable +from typing import Any, Dict, List, Optional, Union, Iterable, Literal MAX_PAYLOAD_SIZE = 4096 +InterruptionLevelType = Literal['active', 'passive', 'time-sensitive', 'critical'] + class PayloadAlert(object): def __init__( @@ -79,6 +81,7 @@ def __init__( thread_id: Optional[str] = None, content_available: bool = False, mutable_content: bool = False, + interruption_level: Union[InterruptionLevelType, None] = None, ) -> None: self.alert = alert self.badge = badge @@ -89,6 +92,18 @@ def __init__( self.custom = custom self.mutable_content = mutable_content self.thread_id = thread_id + self.interruption_level = interruption_level + + @property + def interruption_level(self): + return self._interruption_level + + @interruption_level.setter + def interruption_level(self, value): + if value and value not in InterruptionLevelType.__args__: + raise Exception("-interruption_level- it must be at least a value of InterruptionLevelType or None: For further visit https://developer.apple.com/documentation/usernotifications/unnotificationinterruptionlevel. Valid values are: {} ".format(list(InterruptionLevelType.__args__))) + self._interruption_level = value + def dict(self) -> Dict[str, Any]: result = { @@ -114,6 +129,8 @@ def dict(self) -> Dict[str, Any]: result['aps']['category'] = self.category if self.url_args is not None: result['aps']['url-args'] = self.url_args + if self.interruption_level is not None: + result['aps']['interruption-level'] = self.interruption_level if self.custom is not None: result.update(self.custom) diff --git a/test/test_payload.py b/test/test_payload.py index c56b742..7337c37 100644 --- a/test/test_payload.py +++ b/test/test_payload.py @@ -60,9 +60,17 @@ def test_payload(): def test_payload_with_payload_alert(payload_alert): payload = Payload( - alert=payload_alert, badge=2, sound='chime', - content_available=True, mutable_content=True, - category='my_category', url_args='args', custom={'extra': 'something'}, thread_id='42') + alert=payload_alert, + badge=2, + sound="chime", + content_available=True, + mutable_content=True, + category="my_category", + url_args="args", + custom={"extra": "something"}, + thread_id="42", + interruption_level=None, + ) assert payload.dict() == { 'aps': { 'alert': { @@ -89,3 +97,38 @@ def test_payload_with_payload_alert(payload_alert): }, 'extra': 'something' } + + +@pytest.mark.parametrize("input_value", ["inactive", "invalid", "default"]) +def test_payload_alert_with_an_unvalid_interruption_level_value(payload_alert, input_value): + with pytest.raises(Exception): + Payload( + alert=payload_alert, + badge=2, + sound="chime", + content_available=True, + mutable_content=True, + category="my_category", + url_args="args", + custom={"extra": "something"}, + thread_id="42", + interruption_level=input_value, + ) + +@pytest.mark.parametrize("input_value", ["active", "passive", "time-sensitive", "critical"]) +def test_payload_alert_with_a_valid_interruption_level_value(payload_alert, input_value): + payload = Payload( + alert=payload_alert, + badge=2, + sound="chime", + content_available=True, + mutable_content=True, + category="my_category", + url_args="args", + custom={"extra": "something"}, + thread_id="42", + interruption_level=input_value, + ) + _payload = payload.dict() + assert "interruption-level" in _payload["aps"] + assert _payload["aps"]["interruption-level"] == input_value